未验证 提交 799f3861 编写于 作者: A Adam Osewski 提交者: GitHub

Reuse OneDNN handler for SGD and SUM for SelectedRows input tensors. (#35510)

* Create stateful OneDNNAXPYHandler object.

This makes it possible to call it multiple times without recreating the
oneDNN primitives every time.

* Prepare SGDOpKernel to reuse its implementation from OneDNN kernel.

* OneDNN SGD kernel.

* Update call to use new OneDNNAXPYHandler object api.

* Setup seed in proper place.

* Enable OneDNN kernel only for single case.

* For dense param and sparse grad.

* Small refactor.

* Enable oneDNN by op attr or by cmd line flag.

* Use int64_t type for number of elements.

* Support dense param and grad from OneDNN kernel.

* Enable SGD OneDNN kernel when use MP BF16 optimizer.

* Force non-copyable/movable OneDNNAXPYHandler.

* Reuse OneDNNAXPYHandler for spare tensors in SUM op.

* Fix SFINAE rules.

* Remove recording event inside AXPY.

* Get rid of internal primitive caching.

* Stop use PP cache mechanims to store mem and primitive obj.
* Handler obj store and reuse needed desc & prim

* Do not derive from MKLDNNHandlerT
上级 86685190
......@@ -301,23 +301,9 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
namespace scatter {
template <typename T>
typename std::enable_if<std::is_same<T, platform::bfloat16>::value>::type
elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) {
#ifdef PADDLE_WITH_MKLDNN
onednn_handler_axpy(data_len, T(1.f), in, out);
#else
blas->AXPY(data_len, T(1.f), in, out);
#endif
}
template <typename T>
typename std::enable_if<std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type
elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) {
typename std::enable_if<!std::is_integral<T>::value>::type elementwise_add_to(
BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len, const T* in,
T* out) {
blas->AXPY(data_len, T(1.f), in, out);
}
......@@ -330,6 +316,64 @@ typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
}
}
template <typename T>
typename std::enable_if<std::is_same<T, platform::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const framework::SelectedRows*>& inputs,
const std::unordered_map<int64_t, size_t>& rows_to_id,
int64_t input_width,
const platform::CPUDeviceContext& context, T* out_data) {
#ifndef PADDLE_WITH_MKLDNN
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
#endif
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
#ifdef PADDLE_WITH_MKLDNN
OneDNNAXPYHandler<T> axpy_handler(input_width, T(1.f));
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id.at(input_rows[i]);
axpy_handler(&input_data[i * input_width],
&out_data[out_i * input_width]);
}
#else
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id.at(input_rows[i]);
elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
&input_data[i * input_width],
&out_data[out_i * input_width]);
}
#endif
}
}
template <typename T>
typename std::enable_if<!std::is_same<T, platform::bfloat16>::value>::type
add_sparse_inputs(const std::vector<const framework::SelectedRows*>& inputs,
const std::unordered_map<int64_t, size_t>& rows_to_id,
int64_t input_width,
const platform::CPUDeviceContext& context, T* out_data) {
VLOG(4) << "[CPU] add_sparse_inputs <" << typeid(T).name();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id.at(input_rows[i]);
elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
&input_data[i * input_width],
&out_data[out_i * input_width]);
}
}
}
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
......@@ -435,21 +479,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
rows_to_id[merge_rows[i]] = i;
}
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
&input_data[i * input_width],
&out_data[out_i * input_width]);
}
}
add_sparse_inputs<T>(inputs, rows_to_id, input_width, context, out_data);
}
}
};
......
......@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -34,76 +33,46 @@ namespace plat = paddle::platform;
namespace {
template <typename T>
class AXPYMKLDNNHandler : public plat::MKLDNNHandlerT<T, dnnl::reorder> {
class AXPYHandler {
public:
AXPYMKLDNNHandler(const plat::MKLDNNDeviceContext &dev_ctx,
const dnnl::engine mkldnn_engine, plat::Place cpu_place,
int n, float alpha)
: plat::MKLDNNHandlerT<T, dnnl::reorder>(
dev_ctx, mkldnn_engine, cpu_place,
plat::CreateKey(dev_ctx, static_cast<int64_t>(n),
plat::MKLDNNGetDataType<T>(), alpha, "-axpy")),
alpha_(alpha),
n_(n) {}
std::shared_ptr<dnnl::memory> AcquireMemory(void *ptr,
const std::string &suffix) {
/*Generate key*/
auto local_key = this->key_ + suffix;
auto mem_p = std::static_pointer_cast<dnnl::memory>(
this->dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto md = dnnl::memory::desc({n_}, plat::MKLDNNGetDataType<T>(),
AXPYHandler(const dnnl::engine mkldnn_engine, int n, float alpha) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
auto md = dnnl::memory::desc({n}, plat::MKLDNNGetDataType<T>(),
dnnl::memory::format_tag::x);
mem_p = std::make_shared<dnnl::memory>(md, this->engine_, ptr);
this->dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const T *x) {
return this->AcquireMemory(plat::to_void_cast(x), "@user_src_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(T *y) {
return this->AcquireMemory(y, "@user_dst_mem_p");
}
std::shared_ptr<dnnl::reorder> AcquireReorder(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
auto prim_key = this->key_ + "@reorder_p";
auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
this->dev_ctx_.GetBlob(prim_key));
if (reorder_p == nullptr) {
// Here we pass Postops to mimick y -> a*X + y
src_mem_ = dnnl::memory(md, mkldnn_engine, DNNL_MEMORY_NONE);
dst_mem_ = dnnl::memory(md, mkldnn_engine, DNNL_MEMORY_NONE);
dnnl::primitive_attr reorder_attr;
dnnl::post_ops post_operations;
if (this->alpha_ != 1.f) {
std::vector<float> scales(1, this->alpha_);
if (alpha != 1.f) {
std::vector<float> scales(1, alpha);
reorder_attr.set_output_scales(0, scales);
}
post_operations.append_sum(1.0f);
reorder_attr.set_post_ops(post_operations);
reorder_p = std::make_shared<dnnl::reorder>(
*(src_memory_p), *(dst_memory_p), reorder_attr);
this->dev_ctx_.SetBlob(prim_key, reorder_p);
reorder_p_ = dnnl::reorder(src_mem_, dst_mem_, reorder_attr);
}
return reorder_p;
dnnl::memory &AcquireSrcMemory(const T *x) {
src_mem_.set_data_handle(plat::to_void_cast<T>(x));
return src_mem_;
}
dnnl::memory &AcquireDstMemory(T *y) {
dst_mem_.set_data_handle(y);
return dst_mem_;
}
const dnnl::reorder &AcquireReorder() { return reorder_p_; }
private:
float alpha_;
int n_;
dnnl::memory src_mem_;
dnnl::memory dst_mem_;
dnnl::reorder reorder_p_;
};
template class AXPYMKLDNNHandler<float>;
template class AXPYMKLDNNHandler<plat::bfloat16>;
} // anonnymouse namespace
template class AXPYHandler<float>;
template class AXPYHandler<plat::bfloat16>;
template <typename T>
static void naive_axpy(int n, T alpha, const T *x, T *y) {
......@@ -114,39 +83,60 @@ static void naive_axpy(int n, T alpha, const T *x, T *y) {
}
}
} // anonnymouse namespace
template <typename T>
void onednn_handler_axpy(int n, T alpha, const T *x, T *y) {
// fallback to naive version
if (n < 100) {
naive_axpy(n, alpha, x, y);
return;
}
class OneDNNAXPYHandler<T>::Impl {
public:
Impl(int64_t n, T alpha);
void operator()(const T *x, T *y);
private:
std::unique_ptr<AXPYHandler<T>> handler_;
int64_t n_;
T alpha_;
};
template <typename T>
OneDNNAXPYHandler<T>::Impl::Impl(int64_t n, T alpha) : n_{n}, alpha_{alpha} {
auto &pool = plat::DeviceContextPool::Instance();
auto cpu_place = plat::CPUPlace();
auto *dev_ctx =
dynamic_cast<plat::MKLDNNDeviceContext *>(pool.Get(cpu_place));
auto &cpu_engine = dev_ctx->GetEngine();
AXPYMKLDNNHandler<T> handler(*dev_ctx, cpu_engine, cpu_place, n,
handler_ = std::make_unique<AXPYHandler<T>>(cpu_engine, n,
static_cast<float>(alpha));
}
auto reorder_src_memory_p = handler.AcquireSrcMemory(x);
auto reorder_dst_memory_p = handler.AcquireDstMemory(y);
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
template <typename T>
void OneDNNAXPYHandler<T>::Impl::operator()(const T *x, T *y) {
if (this->n_ < 100) {
naive_axpy(this->n_, this->alpha_, x, y);
return;
}
auto &reorder_src_mem_p = handler_->AcquireSrcMemory(x);
auto &reorder_dst_mem_p = handler_->AcquireDstMemory(y);
auto reorder_p = handler_->AcquireReorder();
auto &astream = plat::MKLDNNDeviceContext::tls().get_stream();
plat::RecordEvent record_reorder("axpy_int_reorder",
plat::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
reorder_p.execute(astream, reorder_src_mem_p, reorder_dst_mem_p);
astream.wait();
}
template void onednn_handler_axpy<float>(int, float, const float *, float *);
template void onednn_handler_axpy<plat::bfloat16>(int, plat::bfloat16,
const plat::bfloat16 *,
plat::bfloat16 *);
template <typename T>
OneDNNAXPYHandler<T>::OneDNNAXPYHandler(int64_t n, T alpha)
: pimpl_{new Impl{n, alpha}, [](Impl *impl) { delete impl; }} {
VLOG(4) << "[OneDNN] OneDNNAXPYHandler<" << typeid(T).name() << ">, "
<< "n: " << n << ", alpha: " << alpha;
}
template <typename T>
void OneDNNAXPYHandler<T>::operator()(const T *x, T *y) {
pimpl_->operator()(x, y);
}
template class OneDNNAXPYHandler<float>;
template class OneDNNAXPYHandler<plat::bfloat16>;
} // namespace operators
} // namespace paddle
......@@ -13,21 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
namespace paddle {
namespace operators {
///
/// @brief Helper function to execute AXPY using oneDNN.
///
/// @param[in] n The number of elements in tensor (assumed 1D)
/// @param[in] alpha The alpha coefficient.
/// @param[in] x The pointer to input X tensor.
/// @param y The pointer to output Y tensor.
/// @brief Helper class for AXPY execution using oneDNN library.
///
/// @tparam T Data type.
///
template <typename T>
void onednn_handler_axpy(int n, T alpha, const T *x, T *y);
class OneDNNAXPYHandler {
public:
OneDNNAXPYHandler(OneDNNAXPYHandler&) = delete;
OneDNNAXPYHandler(OneDNNAXPYHandler&&) = delete;
OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&) = delete;
OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&&) = delete;
///
/// @brief Constructor.
///
/// @param[in] n The number of elements in tensor (assumed 1D tensor)
/// @param[in] alpha The alpha coefficient.
///
OneDNNAXPYHandler(int64_t n, T alpha);
///
/// @brief Executes AXPY.
///
/// @param[in] x The pointer to input X tensor data.
/// @param[out] y The pointer to output Y tensor data.
///
void operator()(const T* x, T* y);
private:
OneDNNAXPYHandler() = delete;
// (arogowie-intel) Private implementation idiom to hide dependency
// on OneDNN headers.
class Impl;
// We need custom deleter, since the compiler is unable to parameterize
// an allocator's default deleter due to incomple type.
std::unique_ptr<Impl, void (*)(Impl*)> pimpl_;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstring>
#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
namespace pplat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
class SGDOneDNNKernel : public SGDOpKernel<pplat::CPUDeviceContext, T> {
protected:
void dense_param_and_grad_kernel(
const framework::ExecutionContext &ctx) const override {
VLOG(4) << "[ONEDNN]: sgd_dense_param_kernel<T, LodTensor>";
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *param = ctx.Input<framework::Tensor>("Param");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
const auto *grad = ctx.Input<framework::Tensor>("Grad");
auto *out_data = param_out->mutable_data<T>(ctx.GetPlace());
const T *param_data = param->data<T>();
const auto *grad_data = grad->data<T>();
const auto *lr = learning_rate->data<T>();
// Since denese SGD is not in place operation, first copy params to output
// tensor and then update it.
std::memcpy(out_data, param_data, param->memory_size());
OneDNNAXPYHandler<T>(param_out->numel(), -lr[0])(grad_data, out_data);
}
void dense_param_sparse_grad_kernel(
const framework::ExecutionContext &ctx) const override {
VLOG(4) << "[ONEDNN]: sgd_dense_param_kernel<T, SelectedRows>";
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
const auto *grad = ctx.Input<framework::SelectedRows>("Grad");
const auto &grad_value = grad->value();
const auto &grad_rows = grad->rows();
const auto grad_height = grad->height();
const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
const auto grad_width = grad_value.numel() / grad_val_height;
const auto *grad_data = grad_value.data<T>();
auto *out_data = param_out->data<T>();
const auto *lr = learning_rate->data<T>();
OneDNNAXPYHandler<T> axpy_handler(grad_width, -lr[0]);
for (size_t i = 0; i < grad_rows.size(); ++i) {
PADDLE_ENFORCE_LT(
grad_rows[i], grad_height,
pplat::errors::OutOfRange(
"Grad rows index value should be less than grad height."
"Got [%s], but expected less than [%s]",
grad_rows[i], grad_height));
const int64_t row = grad_rows[i];
const auto *src = grad_data + i * grad_width;
auto *dst = out_data + row * grad_width;
axpy_handler(src, dst);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(sgd, MKLDNN, pplat::CPUPlace, ops::SGDOneDNNKernel<float>,
ops::SGDOneDNNKernel<pplat::bfloat16>);
......@@ -15,6 +15,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/optimizers/sgd_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -67,6 +70,26 @@ class SGDOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
#ifdef PADDLE_WITH_MKLDNN
using mkldnn::memory;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
const auto *param_var = ctx.InputVar("Param");
const auto *grad_var = ctx.InputVar("Grad");
// supported cases
bool dense_param_sparse_grad =
param_var->IsType<framework::LoDTensor>() &&
grad_var->IsType<framework::SelectedRows>();
bool dense_param_and_grad = param_var->IsType<framework::LoDTensor>() &&
grad_var->IsType<framework::LoDTensor>();
if (dense_param_sparse_grad || dense_param_and_grad)
return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.device_context());
}
......@@ -106,6 +129,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param");
AddAttr<bool>(
"use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddComment(R"DOC(
SGD operator
......
......@@ -19,9 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/jit/kernels.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
#endif
#include "paddle/fluid/platform/bfloat16.h"
namespace paddle {
......@@ -142,21 +139,44 @@ struct sgd_dense_param_kernel<
"Got [%s], but expected less than [%s]",
grad_rows[i], grad_height));
const int64_t row = grad_rows[i];
#ifdef PADDLE_WITH_MKLDNN
operators::onednn_handler_axpy(grad_width, -lr[0],
grad_data + i * grad_width,
out_data + row * grad_width);
#else
for (int64_t j = 0; j < grad_width; ++j) {
out_data[row * grad_width + j] -= lr[0] * grad_data[i * grad_width + j];
}
#endif
}
}
};
} // namespace detail
template <typename DeviceContext, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename T>
void sgd_op_invoke_dense_param_kernel(const framework::ExecutionContext &ctx) {
class SGDOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *param_var = ctx.InputVar("Param");
if (param_var->IsType<framework::LoDTensor>()) {
invoke_dense_param_kernel(ctx);
} else if (param_var->IsType<framework::SelectedRows>()) {
sparse_param_and_grad_kernel(ctx);
} else {
PADDLE_ENFORCE_EQ(
false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Parameter in SgdOp. Excepted "
"LodTensor or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(param_var->Type())));
}
}
protected:
void invoke_dense_param_kernel(const framework::ExecutionContext &ctx) const {
const auto *param = ctx.Input<framework::Tensor>("Param");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
const auto *grad_var = ctx.InputVar("Grad");
......@@ -179,8 +199,7 @@ void sgd_op_invoke_dense_param_kernel(const framework::ExecutionContext &ctx) {
"numel = [%s], ParamOut's numel = [%s]",
grad->numel(), sz));
sgd_dense_param_kernel<
T, framework::VarTypeTrait<framework::LoDTensor>::kId>()(ctx);
dense_param_and_grad_kernel(ctx);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
......@@ -223,38 +242,23 @@ void sgd_op_invoke_dense_param_kernel(const framework::ExecutionContext &ctx) {
"grad_value's numel [%s] and param_out's numel [%s]",
grad_width, param_width));
sgd_dense_param_kernel<
T, framework::VarTypeTrait<framework::SelectedRows>::kId>()(ctx);
dense_param_sparse_grad_kernel(ctx);
} else {
PADDLE_ENFORCE_EQ(
false, true, platform::errors::PermissionDenied(
false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Grad in SgdOp. Excepted "
"LodTensor or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(grad_var->Type())));
}
}
} // namespace detail
template <typename DeviceContext, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
}
template <typename T>
class SGDOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
void sparse_param_and_grad_kernel(
const framework::ExecutionContext &ctx) const {
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *param_var = ctx.InputVar("Param");
const auto *grad_var = ctx.InputVar("Grad");
if (param_var->IsType<framework::LoDTensor>()) {
detail::sgd_op_invoke_dense_param_kernel<T>(ctx);
} else if (param_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::SelectedRows>(), true,
platform::errors::InvalidArgument(
"When param is SelectedRows, gradient should also "
......@@ -294,14 +298,18 @@ class SGDOpKernel<platform::CPUDeviceContext, T>
lr[0] * grad_data[i * grad_row_width + j];
}
}
} else {
PADDLE_ENFORCE_EQ(
false, true,
platform::errors::PermissionDenied(
"Unsupported Variable Type of Parameter in SgdOp. Excepted "
"LodTensor or SelectedRows, But received [%s]",
paddle::framework::ToTypeName(param_var->Type())));
}
virtual void dense_param_and_grad_kernel(
const framework::ExecutionContext &ctx) const {
detail::sgd_dense_param_kernel<
T, framework::VarTypeTrait<framework::LoDTensor>::kId>()(ctx);
}
virtual void dense_param_sparse_grad_kernel(
const framework::ExecutionContext &ctx) const {
detail::sgd_dense_param_kernel<
T, framework::VarTypeTrait<framework::SelectedRows>::kId>()(ctx);
}
};
......
......@@ -42,6 +42,8 @@ class OptimizerWithMixedPrecision(object):
def __init__(self, optimizer, amp_lists, use_pure_bf16, use_bf16_guard):
self._optimizer = optimizer
if optimizer.type == 'sgd':
optimizer._use_mkldnn = True
self._amp_lists = amp_lists
self._param_grads = None
self._train_program = None
......
......@@ -1305,6 +1305,7 @@ class SGDOptimizer(Optimizer):
grad_clip=grad_clip,
name=name)
self.type = "sgd"
self._use_mkldnn = False
@no_grad
def _append_optimize_op(self, block, param_and_grad):
......@@ -1323,6 +1324,7 @@ class SGDOptimizer(Optimizer):
"Grad": param_and_grad[1],
"LearningRate": lr
},
attrs={"use_mkldnn": self._use_mkldnn},
outputs={"ParamOut": param_and_grad[0]},
stop_gradient=True)
......
......@@ -32,6 +32,7 @@ class TestSGDOpBF16(OpTest):
def setUp(self):
self.op_type = 'sgd'
self.dtype = np.uint16
self.use_mkldnn = True
self.conf()
w = np.random.random((self.h, self.w)).astype('float32')
w_bf16 = convert_float_to_uint16(w)
......@@ -42,6 +43,7 @@ class TestSGDOpBF16(OpTest):
self.inputs = {'Param': w_bf16, 'Grad': g_bf16, 'LearningRate': lr_bf16}
self.outputs = {'ParamOut': w - lr * g}
self.attrs = {'use_mkldnn': self.use_mkldnn}
def conf(self):
self.h = 102
......@@ -53,7 +55,7 @@ class TestSGDOpBF16(OpTest):
@unittest.skipIf(not core.supports_bfloat16(),
'place does not support BF16 evaluation')
class TestSGDOpCase8XBF16(TestSGDOpBF16):
class TestSGDOpBF16Case2(TestSGDOpBF16):
def conf(self):
self.h = 10
self.w = 64
......@@ -142,7 +144,8 @@ class TestSparseGradSGDOpBF16(TestSparseSGDOpBF16):
Param='Param',
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate')
LearningRate='LearningRate',
use_mkldnn=True)
sgd_op.run(scope, place)
reference = self.ref_optimize(param_array, self.grad_rows, grad_array,
......@@ -194,7 +197,8 @@ class TestSparseGradParamSGDOpBF16(TestSparseSGDOpBF16):
Param='Param',
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate')
LearningRate='LearningRate',
use_mkldnn=True)
sgd_op.run(scope, place)
reference = self.ref_optimize(param_array, self.grad_rows, grad_array,
......@@ -213,6 +217,11 @@ class TestSparseGradParamSGDOpBF16Case2(TestSparseGradParamSGDOpBF16):
@OpTestTool.skip_if_not_cpu_bf16()
class TestSGDOpBF16API(unittest.TestCase):
@classmethod
def setUpClass(cls):
np.random.seed(12345)
fluid.set_flags({'FLAGS_use_mkldnn': True})
def setUp(self):
self.sample_count = 20
self.value = np.random.random()
......@@ -222,9 +231,7 @@ class TestSGDOpBF16API(unittest.TestCase):
self.y_shape = (32, 16)
self.learning_rate = 0.1
np.random.seed(12345)
self._set_initializer()
fluid.set_flags({'FLAGS_use_mkldnn': True})
def _fp322bf16(self, val: np.float32):
return np.uint16(struct.unpack('<I', struct.pack('<f', val))[0] >> 16)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册