Unverified commit c923e6c9, authored by Chen Weihang, committed by GitHub

Adapting device-specific Extra Attributes for the PHI kernel (#46342)

* add extra attr property set

* add type_info for all context

* add onednn context to all context

* fix context compile error

* simplify conv kernel args

* pass runtime attr into dev_ctx

* fix macro error

* clear conv_grad_kernel extra args

* merge conv_grad_grad into conv_grad

* clear conv2d_grad_grad extra attrs

* clear yaml and eager extra attr

* fix conv1d error

* change to thread local

* fix npu compile failed

* try to fix windows compile failed

* add conv2d onednn phi kernel

* fix ci bugs (#36)

* fix compile bugs (#38)

* fix extra input transform bug (#39)

* support dynamic created attr (#40)

* reset extra info gen code

* rm conv_grad_grad kernel

* reimpl pass attr adapting

* add int attr support

* remove creation of input names vector

* fix map::at error

* Update paddle/phi/kernels/onednn/conv_grad_kernel.cc
Co-authored-by: Sławomir Siwek <slawomir.siwek@intel.com>

* remove useless extra attrs

* replace mkldnn_engine by onednn_engine
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: Sławomir Siwek <slawomir.siwek@intel.com>
Parent f82d7e3c
......@@ -24,10 +24,7 @@ paddle::experimental::Tensor conv2d_ad_func(
const paddle::experimental::Tensor& filter,
std::vector<int> strides,
std::vector<int> paddings,
std::string paddding_algorithm,
int groups,
std::string padding_algorithm,
std::vector<int> dilations,
std::string data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search);
int groups,
std::string data_format);
......@@ -29,13 +29,10 @@ paddle::experimental::Tensor conv2d_ad_func(
const paddle::experimental::Tensor& filter,
std::vector<int> strides,
std::vector<int> paddings,
std::string paddding_algorithm,
int groups,
std::string padding_algorithm,
std::vector<int> dilations,
std::string data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search) {
int groups,
std::string data_format) {
// Dygraph Record Event
paddle::platform::RecordEvent dygraph_entrance_record_event(
"conv2d dygraph", paddle::platform::TracerEventType::Operator, 1);
......@@ -64,13 +61,10 @@ paddle::experimental::Tensor conv2d_ad_func(
new_filter,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search);
groups,
data_format);
}
}
......@@ -92,13 +86,10 @@ paddle::experimental::Tensor conv2d_ad_func(
filter,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search);
groups,
data_format);
transformer->SetOutTensorLayout(&out);
if (need_tune) {
egr::Controller::Instance().EnableLayoutAutoTune();
......@@ -119,13 +110,10 @@ paddle::experimental::Tensor conv2d_ad_func(
filter,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search);
groups,
data_format);
// Check NaN and Inf if needed
if (FLAGS_check_nan_inf) {
egr::CheckTensorHasNanOrInf("conv2d", api_result);
......@@ -157,13 +145,10 @@ paddle::experimental::Tensor conv2d_ad_func(
// SetAttributes if needed
grad_node->SetAttributestrides(strides);
grad_node->SetAttributepaddings(paddings);
grad_node->SetAttributepaddding_algorithm(paddding_algorithm);
grad_node->SetAttributepadding_algorithm(padding_algorithm);
grad_node->SetAttributegroups(groups);
grad_node->SetAttributedilations(dilations);
grad_node->SetAttributedata_format(data_format);
grad_node->SetAttributeuse_addto(use_addto);
grad_node->SetAttributeworkspace_size_MB(workspace_size_MB);
grad_node->SetAttributeexhaustive_search(exhaustive_search);
// Set TensorWrappers for Forward Inputs if needed
grad_node->SetTensorWrapperinput(input);
grad_node->SetTensorWrapperfilter(filter);
......
......@@ -46,13 +46,10 @@ Conv2dGradNodeFinal::operator()(
auto& grad_out = hooked_grads[0][0];
auto& strides = this->strides_;
auto& paddings = this->paddings_;
auto& paddding_algorithm = this->paddding_algorithm_;
auto& padding_algorithm = this->padding_algorithm_;
auto& groups = this->groups_;
auto& dilations = this->dilations_;
auto& data_format = this->data_format_;
auto& use_addto = this->use_addto_;
auto& workspace_size_MB = this->workspace_size_MB_;
auto& exhaustive_search = this->exhaustive_search_;
// Prepare Grad function call
const auto& out_metas = OutputMeta();
......@@ -87,13 +84,10 @@ Conv2dGradNodeFinal::operator()(
grad_out,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
api_output_0,
api_output_1);
// Check NaN and Inf if needed
......@@ -134,13 +128,10 @@ Conv2dGradNodeFinal::operator()(
// SetAttributes if needed
grad_node->SetAttributestrides(strides);
grad_node->SetAttributepaddings(paddings);
grad_node->SetAttributepaddding_algorithm(paddding_algorithm);
grad_node->SetAttributepadding_algorithm(padding_algorithm);
grad_node->SetAttributegroups(groups);
grad_node->SetAttributedilations(dilations);
grad_node->SetAttributedata_format(data_format);
grad_node->SetAttributeuse_addto(use_addto);
grad_node->SetAttributeworkspace_size_MB(workspace_size_MB);
grad_node->SetAttributeexhaustive_search(exhaustive_search);
// Set TensorWrappers for Forward Inputs if needed
grad_node->SetTensorWrapperinput(input);
grad_node->SetTensorWrapperfilter(filter);
......@@ -215,13 +206,10 @@ Conv2dDoubleGradNodeFinal::operator()(
auto& strides = this->strides_;
auto& paddings = this->paddings_;
auto& paddding_algorithm = this->paddding_algorithm_;
auto& padding_algorithm = this->padding_algorithm_;
auto& groups = this->groups_;
auto& dilations = this->dilations_;
auto& data_format = this->data_format_;
auto& use_addto = this->use_addto_;
auto& workspace_size_MB = this->workspace_size_MB_;
auto& exhaustive_search = this->exhaustive_search_;
// Prepare Grad function call
const auto& out_metas = OutputMeta();
......@@ -261,13 +249,10 @@ Conv2dDoubleGradNodeFinal::operator()(
grad_filter_grad_optional,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
api_output_0,
api_output_1,
api_output_2);
......
......@@ -63,8 +63,8 @@ class Conv2dGradNodeFinal : public egr::GradNodeBase {
void SetAttributepaddings(const std::vector<int>& paddings) {
paddings_ = paddings;
}
void SetAttributepaddding_algorithm(const std::string& paddding_algorithm) {
paddding_algorithm_ = paddding_algorithm;
void SetAttributepadding_algorithm(const std::string& padding_algorithm) {
padding_algorithm_ = padding_algorithm;
}
void SetAttributegroups(const int& groups) { groups_ = groups; }
void SetAttributedilations(const std::vector<int>& dilations) {
......@@ -73,13 +73,6 @@ class Conv2dGradNodeFinal : public egr::GradNodeBase {
void SetAttributedata_format(const std::string& data_format) {
data_format_ = data_format;
}
void SetAttributeuse_addto(const bool& use_addto) { use_addto_ = use_addto; }
void SetAttributeworkspace_size_MB(const int& workspace_size_MB) {
workspace_size_MB_ = workspace_size_MB;
}
void SetAttributeexhaustive_search(const bool& exhaustive_search) {
exhaustive_search_ = exhaustive_search;
}
private:
// TensorWrappers
......@@ -89,13 +82,10 @@ class Conv2dGradNodeFinal : public egr::GradNodeBase {
// Attributes
std::vector<int> strides_;
std::vector<int> paddings_;
std::string paddding_algorithm_;
std::string padding_algorithm_;
int groups_;
std::vector<int> dilations_;
std::string data_format_;
bool use_addto_;
int workspace_size_MB_;
bool exhaustive_search_;
};
class Conv2dDoubleGradNodeFinal : public egr::GradNodeBase {
......@@ -146,8 +136,8 @@ class Conv2dDoubleGradNodeFinal : public egr::GradNodeBase {
void SetAttributepaddings(const std::vector<int>& paddings) {
paddings_ = paddings;
}
void SetAttributepaddding_algorithm(const std::string& paddding_algorithm) {
paddding_algorithm_ = paddding_algorithm;
void SetAttributepadding_algorithm(const std::string& padding_algorithm) {
padding_algorithm_ = padding_algorithm;
}
void SetAttributegroups(const int& groups) { groups_ = groups; }
void SetAttributedilations(const std::vector<int>& dilations) {
......@@ -156,13 +146,6 @@ class Conv2dDoubleGradNodeFinal : public egr::GradNodeBase {
void SetAttributedata_format(const std::string& data_format) {
data_format_ = data_format;
}
void SetAttributeuse_addto(const bool& use_addto) { use_addto_ = use_addto; }
void SetAttributeworkspace_size_MB(const int& workspace_size_MB) {
workspace_size_MB_ = workspace_size_MB;
}
void SetAttributeexhaustive_search(const bool& exhaustive_search) {
exhaustive_search_ = exhaustive_search;
}
private:
// TensorWrappers
......@@ -173,13 +156,10 @@ class Conv2dDoubleGradNodeFinal : public egr::GradNodeBase {
// Attributes
std::vector<int> strides_;
std::vector<int> paddings_;
std::string paddding_algorithm_;
std::string padding_algorithm_;
int groups_;
std::vector<int> dilations_;
std::string data_format_;
bool use_addto_;
int workspace_size_MB_;
bool exhaustive_search_;
};
class AddNGradNodeFinal : public egr::GradNodeBase {
......
......@@ -32,8 +32,8 @@
#include <valarray>
#include <vector>
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/expect.h"
namespace paddle {
namespace framework {
......
......@@ -30,7 +30,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/expect.h"
#include "paddle/phi/core/expect.h"
namespace paddle {
namespace framework {
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -2269,7 +2270,8 @@ Scope* OperatorWithKernel::PrepareData(
}
std::unique_ptr<OpKernelType> new_expected_kernel_key = nullptr;
if (run_phi_kernel_ && in_def->backend != phi::Backend::ALL_BACKEND) {
if (run_phi_kernel_ && in_def != nullptr &&
in_def->backend != phi::Backend::ALL_BACKEND) {
auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
if ((in_def->backend != tensor_backend &&
(in_def->backend != phi::Backend::GPUDNN ||
......@@ -2388,7 +2390,6 @@ Scope* OperatorWithKernel::PrepareData(
input_names.size(),
input_defs.size()));
for (size_t i = 0; i < input_defs.size(); ++i) {
const auto& input_defs = phi_kernel_->args_def().input_defs();
auto& in_def = input_defs.at(i);
std::string input_name = input_names[i];
auto iter = ctx->inputs.find(input_name);
......@@ -2400,6 +2401,22 @@ Scope* OperatorWithKernel::PrepareData(
no_buffer_ins && no_buffer_ins->count(input_name) > 0;
prepare_input_data(input_name, &ins_vector, &in_def, should_skip_input);
}
#ifdef PADDLE_WITH_MKLDNN
// For input that is Extra, only MKLDNN will use Extra Inputs
auto& extra_input_names =
paddle::operators::ExtraInfoUtils::Instance().GetExtraInputNamesMap(
Type());
for (const auto& input_name : extra_input_names) {
auto iter = ctx->inputs.find(input_name);
if (iter == ctx->inputs.end()) {
continue;
}
bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(input_name) > 0;
std::vector<Variable*>& input_vars = iter->second;
prepare_input_data(input_name, &input_vars, nullptr, should_skip_input);
}
#endif
} else {
for (auto& var_name_item : Inputs()) {
bool should_skip_input =
......@@ -2699,6 +2716,65 @@ phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
return (*arg_map_fn_)(arg_mapping_ctx);
}
static void SetDnnAttrIntoDeviceContext(
phi::DeviceContext* dev_ctx,
const Attribute& attr,
const std::string& attr_name,
const operators::ExtraAttrPropertySet& attr_propertys) {
#ifdef PADDLE_WITH_MKLDNN
if (phi::OneDNNContext::classof(dev_ctx) &&
attr_propertys.Support(operators::ExtraAttrProperty::ONEDNN)) {
VLOG(4) << "Runtime attr `" << attr_name << "` is passed to OneDNNContext.";
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
switch (AttrTypeID(attr)) {
case proto::AttrType::FLOAT:
one_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(float, attr));
break;
case proto::AttrType::INT:
one_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(int, attr));
break;
case proto::AttrType::STRING:
one_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(std::string, attr));
break;
case proto::AttrType::INTS:
one_dnn_ctx->SetDnnAttr(attr_name,
PADDLE_GET_CONST(std::vector<int>, attr));
break;
case proto::AttrType::FLOATS:
one_dnn_ctx->SetDnnAttr(attr_name,
PADDLE_GET_CONST(std::vector<float>, attr));
break;
case proto::AttrType::BOOLEAN:
one_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(bool, attr));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Attribute value type `%s` for phi.",
platform::demangle(attr.type().name())));
}
}
#endif
#ifdef PADDLE_WITH_CUDA
if (phi::GPUContext::classof(dev_ctx) &&
attr_propertys.Support(operators::ExtraAttrProperty::GPUDNN)) {
VLOG(4) << "Runtime attr `" << attr_name << "` is passed to GPUDNNContext.";
phi::GPUContext* gpu_dnn_ctx = static_cast<phi::GPUContext*>(dev_ctx);
switch (AttrTypeID(attr)) {
case proto::AttrType::INT:
gpu_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(int, attr));
break;
case proto::AttrType::BOOLEAN:
gpu_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(bool, attr));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Attribute value type `%s` for phi.",
platform::demangle(attr.type().name())));
}
}
#endif
}
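On the kernel side, an attribute stored by SetDnnAttrIntoDeviceContext can be read back through the context's query interface. A minimal sketch for the GPUDNN case, assuming `dev_ctx` is the phi::GPUContext passed to a kernel and using the `exhaustive_search` attribute listed among the conv extra attrs:

    // Inside a GPUDNN PHI kernel: fall back to false if the runtime attr
    // was not forwarded by the framework.
    bool exhaustive_search =
        dev_ctx.HasDnnAttr("exhaustive_search")
            ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("exhaustive_search"))
            : false;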
void OperatorWithKernel::BuildPhiKernelContext(
const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx,
......@@ -2713,6 +2789,15 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto attr_defs = phi_kernel_->args_def().attribute_defs();
auto output_defs = phi_kernel_->args_def().output_defs();
#if defined(PADDLE_WITH_MKLDNN)
if (phi::OneDNNContext::classof(dev_ctx)) {
// OneDNN holds this op's variable names and initializes them here.
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
one_dnn_ctx->SetInputsName(Inputs());
one_dnn_ctx->SetOutputsName(Outputs());
}
#endif
PADDLE_ENFORCE_EQ(input_names.size(),
input_defs.size(),
platform::errors::InvalidArgument(
......@@ -2992,6 +3077,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
} break;
default: {
if (attr_iter == Attrs().end()) {
// TODO(chenweihang): remove this backup searching later
attr_iter = RuntimeAttrs().find(attr_names[i]);
PADDLE_ENFORCE_NE(attr_iter,
RuntimeAttrs().end(),
......@@ -3075,6 +3161,63 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
}
VLOG(4) << "Done attributes";
// For compatible with Op with extra attrs for specific backend
#if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
auto& runtime_attrs = RuntimeAttrs();
for (const auto& attr_iter : runtime_attrs) {
auto& attr_name = attr_iter.first;
auto& attr = attr_iter.second;
auto attr_propertys = paddle::operators::GetExtraAttrPropertys(attr_name);
SetDnnAttrIntoDeviceContext(dev_ctx, attr, attr_name, attr_propertys);
}
// TODO(chenweihang): Since the pass will still `SetAttr` in the OpDesc,
// we try to add these Attrs to the RuntimeAttrs, but the OpDesc will lose
// the RuntimeAttrs information in the process of converting the Graph to
// the Program, so additional record configuration would have to be
// introduced, which increases the cost of development and understanding.
// Therefore, for the time being, we still read the attributes set by these
// passes from Attrs. In the future, it is necessary to clarify the
// positioning of RuntimeAttrs and expand related functions.
auto& attrs = Attrs();
for (const auto& attr_iter : attrs) {
auto& attr_name = attr_iter.first;
auto& attr = attr_iter.second;
auto attr_propertys = paddle::operators::GetExtraAttrPropertys(attr_name);
SetDnnAttrIntoDeviceContext(dev_ctx, attr, attr_name, attr_propertys);
}
VLOG(4) << "Done runtime attributes";
#endif
// For compatible with Op with extra input for onednn backend
#ifdef PADDLE_WITH_MKLDNN
if (phi::OneDNNContext::classof(dev_ctx)) {
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
auto& extra_input_names =
paddle::operators::ExtraInfoUtils::Instance().GetExtraInputNamesMap(
Type());
for (const auto& input_name : extra_input_names) {
auto it = ctx.inputs.find(input_name);
if (it == ctx.inputs.end() || it->second.size() == 0) {
one_dnn_ctx->SetDnnInput(input_name, nullptr);
} else {
auto ins_vector = it->second;
PADDLE_ENFORCE_EQ(
ins_vector.size(),
1UL,
phi::errors::InvalidArgument(
"OneDNN's extra input only allows one input tensor."));
auto* var = ins_vector[0];
PADDLE_ENFORCE_EQ(var->IsType<phi::DenseTensor>(),
true,
phi::errors::InvalidArgument(
"OneDNN's extra input only can be DenseTensor."));
one_dnn_ctx->SetDnnInput(input_name, &(var->Get<phi::DenseTensor>()));
}
}
}
VLOG(4) << "Done runtime extra inputs";
#endif
}
} // namespace framework
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h"
#include "paddle/phi/core/expect.h"
namespace paddle {
namespace operators {
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h"
#include "paddle/phi/core/expect.h"
namespace paddle {
namespace operators {
......
......@@ -14,11 +14,11 @@
#include <tuple>
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/core/expect.h"
#include "paddle/phi/core/visit_type.h"
......@@ -1184,20 +1184,6 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConvMKLDNNOpKernel<float>,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
ops::ConvMKLDNNOpKernel<uint8_t>,
ops::ConvMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(conv2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConvMKLDNNGradOpKernel<float>,
ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(depthwise_conv2d,
MKLDNN,
::paddle::platform::CPUPlace,
......
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/core/expect.h"
namespace paddle {
namespace operators {
......
......@@ -36,7 +36,7 @@ PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP_ITSELF(conv2d);
USE_OP_DEVICE_KERNEL(conv2d, MKLDNN);
PD_DECLARE_KERNEL(conv2d, OneDNN, ONEDNN);
namespace paddle {
namespace operators {
......
......@@ -14,11 +14,137 @@
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
namespace paddle {
namespace operators {
// This file exists to stay compatible with the bad design and
// implementation of fluid in the past.
// Many operators in fluid have extra attributes, which are generally added
// to implement some specific kernel selection and to meet the specialization
// needs of a specific operation library like mkldnn or cudnn.
enum class ExtraAttrProperty : uint8_t {
// The attributes that are no longer used in any scenario
DEPRECATED = 0,
// The attributes used for framework execution scheduling,
// such as `use_mkldnn`, `use_cudnn`, no need to save
SCHEDULE,
// The attributes for ONEDNN only, can be saved in OneDNNContext
ONEDNN,
// The attributes for GPUDNN only, can be saved in GPUContext
GPUDNN,
// Add necessary properties as needed
};
class ExtraAttrPropertySet final {
public:
constexpr ExtraAttrPropertySet() : bitset_(0) {}
constexpr ExtraAttrPropertySet(ExtraAttrProperty e) // NOLINT
: bitset_(e == ExtraAttrProperty::DEPRECATED
? 0
: 1ULL << (static_cast<uint8_t>(e) - 1)) {}
inline uint64_t bitset() const { return bitset_; }
bool inline Support(ExtraAttrProperty e) const {
// DEPRECATED ExtraAttr always return false
return static_cast<bool>(bitset_ & ExtraAttrPropertySet(e).bitset());
}
bool IsEmpty() const { return bitset_ == 0; }
ExtraAttrPropertySet operator|(const ExtraAttrPropertySet& other) const {
return ExtraAttrPropertySet(bitset_ | other.bitset());
}
ExtraAttrPropertySet operator&(const ExtraAttrPropertySet& other) const {
return ExtraAttrPropertySet(bitset_ & other.bitset());
}
ExtraAttrPropertySet operator-(const ExtraAttrPropertySet& other) const {
return ExtraAttrPropertySet(bitset_ & ~other.bitset());
}
ExtraAttrPropertySet operator^(const ExtraAttrPropertySet& other) const {
return ExtraAttrPropertySet(bitset_ ^ other.bitset());
}
bool operator==(const ExtraAttrPropertySet& other) const {
return bitset_ == other.bitset();
}
private:
constexpr ExtraAttrPropertySet(uint64_t bitset) : bitset_(bitset) {}
uint64_t bitset_;
};
const std::unordered_map<std::string, ExtraAttrPropertySet>
extra_attr_properties = {
// DEPRECATED attributes
{"use_quantizer", ExtraAttrProperty::DEPRECATED},
// SCHEDULE attributes
{"use_cudnn", ExtraAttrProperty::SCHEDULE},
{"use_mkldnn", ExtraAttrProperty::SCHEDULE},
// ONEDNN dedicated attributes
{"Bias", ExtraAttrProperty::ONEDNN},
{"data_format", ExtraAttrProperty::ONEDNN},
{"force_fp32_output", ExtraAttrProperty::ONEDNN},
{"fuse_activation", ExtraAttrProperty::ONEDNN},
{"fuse_activation_type", ExtraAttrProperty::ONEDNN},
{"fuse_activation_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_activation_beta", ExtraAttrProperty::ONEDNN},
{"fuse_activation_scale", ExtraAttrProperty::ONEDNN},
{"fuse_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_beta", ExtraAttrProperty::ONEDNN},
{"fuse_relu", ExtraAttrProperty::ONEDNN},
{"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
{"fuse_with_relu", ExtraAttrProperty::ONEDNN},
{"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
{"fused_transpose_Out", ExtraAttrProperty::ONEDNN},
{"fused_reshape_X", ExtraAttrProperty::ONEDNN},
{"fused_reshape_Y", ExtraAttrProperty::ONEDNN},
{"fused_transpose_X", ExtraAttrProperty::ONEDNN},
{"fused_transpose_Y", ExtraAttrProperty::ONEDNN},
{"mkldnn_data_type", ExtraAttrProperty::ONEDNN},
{"ResidualData", ExtraAttrProperty::ONEDNN},
{"scale_x", ExtraAttrProperty::ONEDNN},
{"scale_y", ExtraAttrProperty::ONEDNN},
{"scale_out", ExtraAttrProperty::ONEDNN},
{"Scale_in", ExtraAttrProperty::ONEDNN},
{"Scale_in_eltwise", ExtraAttrProperty::ONEDNN},
{"Scale_x", ExtraAttrProperty::ONEDNN},
{"Scale_y", ExtraAttrProperty::ONEDNN},
{"Scale_out", ExtraAttrProperty::ONEDNN},
{"Scale_weights", ExtraAttrProperty::ONEDNN},
{"x_data_format", ExtraAttrProperty::ONEDNN},
{"y_data_format", ExtraAttrProperty::ONEDNN},
// ONEDNN pass dedicated attributes
{"Activation_scale", ExtraAttrProperty::ONEDNN},
{"Bias_scales", ExtraAttrProperty::ONEDNN},
{"Output_shift_scale", ExtraAttrProperty::ONEDNN},
{"Sum_scale", ExtraAttrProperty::ONEDNN},
// GPUDNN dedicated attributes
{"exhaustive_search", ExtraAttrProperty::GPUDNN},
{"fuse_relu_before_depthwise_conv", ExtraAttrProperty::GPUDNN},
{"use_addto", ExtraAttrProperty::GPUDNN},
{"workspace_size_MB", ExtraAttrProperty::GPUDNN},
// Mixed-use attributes
{"is_test",
ExtraAttrPropertySet(ExtraAttrProperty::ONEDNN) |
ExtraAttrPropertySet(ExtraAttrProperty::GPUDNN)},
};
inline ExtraAttrPropertySet GetExtraAttrPropertys(
const std::string& attr_name) {
auto iter = extra_attr_properties.find(attr_name);
if (iter != extra_attr_properties.end()) {
return iter->second;
}
return ExtraAttrPropertySet();
}
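A minimal usage sketch of the lookup above: given an extra attribute name, the returned property set tells which backend context the attribute should be routed to (here `use_addto`, which the table marks as GPUDNN-only).

    namespace ops = paddle::operators;
    auto props = ops::GetExtraAttrPropertys("use_addto");
    if (props.Support(ops::ExtraAttrProperty::GPUDNN)) {
      // `use_addto` is a GPUDNN-only extra attribute, so it is forwarded to
      // the GPUContext instead of appearing in the PHI kernel signature.
    }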
template <typename T>
struct ExtraAttrChecker {
ExtraAttrChecker(const std::string& attr_name, T default_value)
......@@ -71,6 +197,15 @@ class ExtraInfoUtils {
return empty_extra_attrs_checker_;
}
const std::vector<std::string>& GetExtraInputNamesMap(
const std::string& op_type) const {
auto iter = g_extra_input_names_map_.find(op_type);
if (iter != g_extra_input_names_map_.end()) {
return iter->second;
}
return empty_extra_input_names_;
}
private:
ExtraInfoUtils();
......@@ -83,6 +218,12 @@ class ExtraInfoUtils {
g_extra_attrs_checker_;
std::vector<std::function<void(framework::AttributeMap*, bool)>>
empty_extra_attrs_checker_{};
// TODO(chenweihang): move these extra inputs into op_compat.yaml
std::unordered_map<std::string, std::vector<std::string>>
g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
{"conv2d_grad", {"Bias"}}};
std::vector<std::string> empty_extra_input_names_;
};
} // namespace operators
......
......@@ -89,7 +89,9 @@ class MLUContext {
DISABLE_COPY_AND_ASSIGN(MLUContext);
};
class MLUDeviceContext : public DeviceContext {
class MLUDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, MLUDeviceContext> {
public:
explicit MLUDeviceContext(MLUPlace place);
virtual ~MLUDeviceContext();
......@@ -148,6 +150,8 @@ class MLUDeviceContext : public DeviceContext {
return thread_ctx_.at(this);
}
static const char* name() { return "MLUDeviceContext"; }
private:
int compute_capability_;
int driver_version_;
......
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
......@@ -28,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/expect.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
......
......@@ -144,7 +144,9 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
// Graphcore IPU
#ifdef PADDLE_WITH_IPU
class IPUDeviceContext : public DeviceContext {
class IPUDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, IPUDeviceContext> {
public:
IPUDeviceContext() = delete;
explicit IPUDeviceContext(IPUPlace place);
......@@ -154,6 +156,8 @@ class IPUDeviceContext : public DeviceContext {
/*! \brief Wait for all operations completion in the stream. */
void Wait() const override;
static const char* name() { return "IPUDeviceContext"; }
private:
IPUPlace place_;
};
......@@ -188,7 +192,9 @@ struct DefaultDeviceContextType<platform::XPUPlace> {
#endif
#ifdef PADDLE_WITH_ASCEND_CL
class NPUDeviceContext : public DeviceContext {
class NPUDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, NPUDeviceContext> {
public:
explicit NPUDeviceContext(NPUPlace place);
virtual ~NPUDeviceContext();
......@@ -224,6 +230,8 @@ class NPUDeviceContext : public DeviceContext {
// void WaitStreamCallback() const { return stream_->WaitCallback(); }
static const char* name() { return "NPUDeviceContext"; }
private:
NPUPlace place_;
aclrtContext context_;
......@@ -248,7 +256,9 @@ struct DefaultDeviceContextType<platform::NPUPlace> {
};
// Currently, NPUPinnedDeviceContext is only used to data copying.
class NPUPinnedDeviceContext : public DeviceContext {
class NPUPinnedDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, NPUPinnedDeviceContext> {
public:
NPUPinnedDeviceContext();
explicit NPUPinnedDeviceContext(NPUPinnedPlace place);
......@@ -257,6 +267,8 @@ class NPUPinnedDeviceContext : public DeviceContext {
Eigen::DefaultDevice* eigen_device() const;
static const char* name() { return "NPUPinnedDeviceContext"; }
private:
NPUPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
......@@ -276,7 +288,9 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
};
// Currently, CUDAPinnedDeviceContext is only used to data copying.
class CUDAPinnedDeviceContext : public DeviceContext {
class CUDAPinnedDeviceContext
: public DeviceContext,
public phi::TypeInfoTraits<DeviceContext, CUDAPinnedDeviceContext> {
public:
CUDAPinnedDeviceContext();
explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
......@@ -285,6 +299,8 @@ class CUDAPinnedDeviceContext : public DeviceContext {
Eigen::DefaultDevice* eigen_device() const;
static const char* name() { return "CUDAPinnedDeviceContext"; }
private:
CUDAPinnedPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
......
......@@ -122,77 +122,80 @@ using namespace ::phi::enforce; // NOLINT
#endif
/*
* Summary: This PADDLE_GET(_**) series macros are used to call paddle::get
* safely. paddle::get is not a completely safe api, although it will not
* go wrong in most cases, but in extreme cases, it may fail and directly
* throw a paddle::bad_variant_access const exception, without any stack
*information.
* This kind of problems is difficult to debug, so add these macros to
* enrich paddle::get error information. At the same time, we restrict
* the direct use of paddle::get by CI rule.
* Summary: This macro is used to get Variable or internal type
* data (such as LoDTensor or SelectedRows) of the Input and
* Output in op, generally used when call scope.FindVar(Input/
* Output("Name")) or ctx.Input<LoDTensor>().
* Firstly this macro check whether the obtained pointer is null,
* and then return data if it is not null.
*
* Note: This macro is only suitable for specific scenarios and
* does not intended to be widely used. If it cannot meet the
* requirements, please use other PADDLE_ENFORCE** check macro.
*
* Parameters:
*     __TYPE: the target variable type
* __VALUE: the target variable to get
*     __PTR: pointer
* __ROLE: (string), Input or Output
* __NAME: (string), Input or Output name
* __OP_TYPE: (string), the op type
*
* Examples:
* - unsafe writing: int x = paddle::get<int>(y);
* - safe writing: int x = PADDLE_GET(int, y);
* Return: The data pointed to by the pointer.
*
* Note: GCC 4.8 cannot select right overloaded function here, so need
* to define different functions and macros here, after we upgreade
* CI gcc version, we can only define one PADDLE_GET macro.
* Examples:
* GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X", "Mul");
*/
namespace details {
using namespace phi::enforce::details; // NOLINT
#define DEFINE_SAFE_PADDLE_GET( \
__InputType, __OutputType, __OutputTypePtr, __FuncName) \
template <typename OutputType, typename InputType> \
auto __FuncName( \
__InputType input, const char* expression, const char* file, int line) \
->typename std::conditional<std::is_pointer<InputType>::value, \
__OutputTypePtr, \
__OutputType>::type { \
try { \
return paddle::get<OutputType>(input); \
} catch (paddle::bad_variant_access const&) { \
HANDLE_THE_ERROR \
throw ::phi::enforce::EnforceNotMet( \
phi::errors::InvalidArgument( \
"paddle::get failed, cannot get value " \
"(%s) by type %s, its type is %s.", \
expression, \
phi::enforce::demangle(typeid(OutputType).name()), \
phi::enforce::demangle(input.type().name())), \
file, \
line); \
END_HANDLE_THE_ERROR \
#define GET_DATA_SAFELY(__PTR, __ROLE, __NAME, __OP_TYPE) \
(([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type { \
auto* __ptr = (__PTR); \
if (UNLIKELY(nullptr == __ptr)) { \
auto __summary__ = phi::errors::NotFound( \
"Unable to get %s data of %s %s in operator %s. " \
"Possible reasons are:\n" \
" 1. The %s is not the %s of operator %s;\n" \
" 2. The %s has no corresponding variable passed in;\n" \
" 3. The %s corresponding variable is not initialized.", \
phi::demangle( \
typeid(std::add_lvalue_reference<decltype(*__ptr)>::type) \
.name()), \
__ROLE, \
__NAME, \
__OP_TYPE, \
__NAME, \
__ROLE, \
__OP_TYPE, \
__NAME, \
__NAME); \
auto __message__ = ::paddle::string::Sprintf( \
"%s\n [Hint: pointer " #__PTR " should not be null.]", \
__summary__.error_message()); \
__THROW_ERROR_INTERNAL__( \
phi::ErrorSummary(__summary__.code(), __message__)); \
} \
}
return *__ptr; \
})())
DEFINE_SAFE_PADDLE_GET(InputType&, OutputType&, OutputType*, SafeBoostGet);
DEFINE_SAFE_PADDLE_GET(const InputType&,
const OutputType&,
const OutputType*,
SafeBoostGetConst);
DEFINE_SAFE_PADDLE_GET(InputType&&,
OutputType,
OutputType*,
SafeBoostGetMutable);
} // namespace details
#define PADDLE_GET(__TYPE, __VALUE) \
paddle::platform::details::SafeBoostGet<__TYPE>( \
__VALUE, #__VALUE, __FILE__, __LINE__)
#define PADDLE_GET_CONST(__TYPE, __VALUE) \
paddle::platform::details::SafeBoostGetConst<__TYPE>( \
__VALUE, #__VALUE, __FILE__, __LINE__)
#define PADDLE_GET_MUTABLE(__TYPE, __VALUE) \
paddle::platform::details::SafeBoostGetMutable<__TYPE>( \
__VALUE, #__VALUE, __FILE__, __LINE__)
/*
* Summary: This macro is used to check whether op has specified
* Input or Output Variables. Because op's Input and Output
* checking are written similarly, so abstract this macro.
*
* Parameters:
*     __EXPR: (bool), the bool expression
* __ROLE: (string), Input or Output
* __NAME: (string), Input or Output name
* __OP_TYPE: (string), the op type
*
* Examples:
* OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
*/
#define OP_INOUT_CHECK(__EXPR, __ROLE, __NAME, __OP_TYPE) \
do { \
PADDLE_ENFORCE_EQ( \
__EXPR, \
true, \
phi::errors::NotFound( \
"No %s(%s) found for %s operator.", __ROLE, __NAME, __OP_TYPE)); \
} while (0)
/** OTHER EXCEPTION AND ENFORCE **/
......
......@@ -528,10 +528,9 @@ struct CannotToStringType {
};
TEST(enforce, cannot_to_string_type) {
static_assert(
!paddle::platform::details::CanToString<CannotToStringType>::kValue,
static_assert(!phi::enforce::details::CanToString<CannotToStringType>::kValue,
"CannotToStringType must not be converted to string");
static_assert(paddle::platform::details::CanToString<int>::kValue,
static_assert(phi::enforce::details::CanToString<int>::kValue,
"int can be converted to string");
CannotToStringType obj1(3), obj2(4), obj3(3);
......
......@@ -312,8 +312,8 @@
func : conj
- backward_op : conv2d_grad
forward : conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
forward : conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
......@@ -324,8 +324,8 @@
backward : conv2d_grad_grad
- backward_op : conv2d_grad_grad
forward : conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
forward : conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
......@@ -357,8 +357,8 @@
backward : conv2d_transpose_double_grad
- backward_op : conv3d_grad
forward : conv3d (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
forward : conv3d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
......@@ -369,8 +369,8 @@
backward : conv3d_grad_grad
- backward_op : conv3d_grad_grad
forward : conv3d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
forward : conv3d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
......@@ -439,21 +439,21 @@
optional : mask
- backward_op : depthwise_conv2d_grad
forward : depthwise_conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn)
forward : depthwise_conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : depthwise_conv2d_grad
param : [input, filter, out_grad, strides, paddings, paddding_algorithm, groups, dilations, data_format, use_addto, workspace_size_MB, exhaustive_search, fuse_relu]
param : [input, filter, out_grad, strides, paddings, padding_algorithm, groups, dilations, data_format, use_addto, workspace_size_MB, exhaustive_search, fuse_relu]
use_gpudnn : use_gpudnn
backward : depthwise_conv2d_grad_grad
- backward_op : depthwise_conv2d_grad_grad
forward : depthwise_conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu)
forward : depthwise_conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
......
......@@ -454,7 +454,7 @@
backward : conj_grad
- op : conv2d
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor
infer_meta :
func : ConvInferMeta
......@@ -474,10 +474,10 @@
backward : conv2d_transpose_grad
- op : conv3d
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
output : Tensor
infer_meta :
func : ConvInferMeta
func : Conv3DInferMeta
kernel :
func : conv3d
use_gpudnn : true
......@@ -564,7 +564,7 @@
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search, bool fuse_relu, bool use_gpudnn)
output : Tensor(out)
infer_meta :
func : ConvInferMeta
func : DepthwiseConvInferMeta
param : [x, filter, strides, paddings, padding_algorithm, groups, dilations, data_format, use_addto, workspace_size_MB, exhaustive_search]
kernel :
func : depthwise_conv2d
......
......@@ -23,9 +23,8 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif
#ifndef PADDLE_WITH_CUSTOM_KERNEL
// TODO(wilber): DeviceContextPool nees include fluid file.
......
......@@ -24,7 +24,8 @@ limitations under the License. */
namespace phi {
class PADDLE_API CPUContext : public DeviceContext {
class PADDLE_API CPUContext : public DeviceContext,
public TypeInfoTraits<DeviceContext, CPUContext> {
public:
CPUContext();
CPUContext(CPUContext&&);
......@@ -34,6 +35,8 @@ class PADDLE_API CPUContext : public DeviceContext {
Eigen::DefaultDevice* eigen_device() const;
const Place& GetPlace() const override;
static const char* name() { return "CPUContext"; }
protected:
// NOTE: External users manage resources. Used in inference scenarios.
// The Set interface is for inference only, DeviceContext will mark the
......
......@@ -21,7 +21,8 @@ limitations under the License. */
namespace phi {
class CustomContext : public DeviceContext {
class CustomContext : public DeviceContext,
public TypeInfoTraits<DeviceContext, CustomContext> {
public:
explicit CustomContext(const CustomPlace&);
......@@ -35,6 +36,8 @@ class CustomContext : public DeviceContext {
// Wait for all operations completion in the stream.
void Wait() const override;
static const char* name() { return "CustomContext"; }
public:
// NOTE: DeviceContext hold resources. Used in training scenarios.
// The interface used by the training scene, DeviceContext will initialize
......
......@@ -717,6 +717,23 @@ struct GPUContext::Impl {
}
}
bool HasDnnAttr(const std::string& attr_name) const {
return dnn_attrs_.count(attr_name) != 0UL;
}
const Attribute& GetDnnAttr(const std::string& attr_name) const {
auto iter = dnn_attrs_.find(attr_name);
PADDLE_ENFORCE_NE(
iter,
dnn_attrs_.end(),
phi::errors::NotFound("Attribute `%s` is not found in OneDNNContext."));
return iter->second;
}
void SetDnnAttr(const std::string& attr_name, Attribute attr) {
dnn_attrs_[attr_name] = attr;
}
// use one flag for all handles?
// they should be accessed consistently
bool owned_{false};
......@@ -780,8 +797,15 @@ struct GPUContext::Impl {
Allocator* allocator_{nullptr}; // external resource.
// A internal resouce to initinalize eigen_device.
std::unique_ptr<internal::EigenGpuStreamDevice> eigen_stream_{nullptr};
// Holds some attributes only used by the gpudnn kernel calculation.
// Because DeviceContext is a global singleton, thread safety must be
// ensured, hence the thread_local variable.
static thread_local AttributeMap dnn_attrs_;
};
thread_local AttributeMap GPUContext::Impl::dnn_attrs_ = {};
GPUContext::GPUContext(GPUContext&&) = default;
GPUContext& GPUContext::operator=(GPUContext&&) = default;
......@@ -1000,4 +1024,16 @@ void GPUContext::SetDriverVersion(int val) { impl_->driver_version_ = val; }
void GPUContext::SetRuntimeVersion(int val) { impl_->runtime_version_ = val; }
bool GPUContext::HasDnnAttr(const std::string& attr_name) const {
return impl_->HasDnnAttr(attr_name);
}
const Attribute& GPUContext::GetDnnAttr(const std::string& attr_name) const {
return impl_->GetDnnAttr(attr_name);
}
void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
return impl_->SetDnnAttr(attr_name, std::move(attr));
}
} // namespace phi
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
......@@ -77,7 +78,8 @@ class DnnWorkspaceHandle {
std::unique_ptr<std::mutex> mtx_;
};
class PADDLE_API GPUContext : public DeviceContext {
class PADDLE_API GPUContext : public DeviceContext,
public TypeInfoTraits<DeviceContext, GPUContext> {
public:
explicit GPUContext(const GPUPlace& place, bool init = true);
......@@ -166,6 +168,13 @@ class PADDLE_API GPUContext : public DeviceContext {
void WaitStreamCallback() const;
// Several methods for adapting Dnn-specific attributes
bool HasDnnAttr(const std::string& attr_name) const;
const Attribute& GetDnnAttr(const std::string& attr_name) const;
void SetDnnAttr(const std::string& attr_name, Attribute attr);
static const char* name() { return "GPUContext"; }
public:
/*! \brief Return nccl communicators. */
ncclComm_t nccl_comm() const;
......@@ -250,10 +259,10 @@ class PADDLE_API GPUContext : public DeviceContext {
std::unique_ptr<Impl> impl_;
};
// Note: In order to register the kernel of CUDNN, GPUDNNContext is required.
// Note: In order to register the kernel of CUDNN, DnnContext is required.
// Currently, CUDNN kernel directly uses GPUContext. But if the kernel function
// has the same name, this will lead to duplicate instantiations of GPU kernel
// and GPUDNN kernel function, so if we using GPUDNNContext = GPUContext, we
// and Dnn kernel function, so if we using DnnContext = GPUContext, we
// must use different function name for cudnn kernel
using GPUDNNContext = GPUContext;
......
......@@ -16,9 +16,10 @@
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/expect.h"
namespace phi {
......@@ -284,6 +285,69 @@ struct OneDNNContext::Impl {
return key_it->second;
}
bool HasDnnAttr(const std::string& attr_name) const {
return dnn_attrs_.count(attr_name) != 0UL;
}
const Attribute& GetDnnAttr(const std::string& attr_name) const {
auto iter = dnn_attrs_.find(attr_name);
PADDLE_ENFORCE_NE(
iter,
dnn_attrs_.end(),
phi::errors::NotFound("Attribute `%s` is not found in OneDNNContext."));
return iter->second;
}
void SetDnnAttr(const std::string& attr_name, Attribute attr) {
dnn_attrs_[attr_name] = attr;
}
bool HasDnnInput(const std::string& input_name) const {
return dnn_inputs_.count(input_name) != 0UL;
}
const DenseTensor* GetDnnInput(const std::string& input_name) const {
auto iter = dnn_inputs_.find(input_name);
PADDLE_ENFORCE_NE(
iter,
dnn_inputs_.end(),
phi::errors::NotFound(
"Input DenseTensor `%s` is not found in OneDNNContext."));
return iter->second;
}
void SetDnnInput(const std::string& input_name, const DenseTensor* input) {
dnn_inputs_[input_name] = input;
}
void SetInputsName(const TensorNameMap& inputs_name) {
inputs_name_ = inputs_name;
}
void SetOutputsName(const TensorNameMap& outputs_name) {
outputs_name_ = outputs_name;
}
const std::vector<std::string>& GetInputsName(
const std::string& input) const {
auto it = inputs_name_.find(input);
PADDLE_ENFORCE_NE(it,
inputs_name_.end(),
phi::errors::NotFound(
"OneDnnContext does not have the input %s.", input));
return it->second;
}
const std::vector<std::string>& GetOutputsName(
const std::string& output) const {
auto it = outputs_name_.find(output);
PADDLE_ENFORCE_NE(
it,
outputs_name_.end(),
phi::errors::NotFound("OneDnnContext does not have the output %s.",
output));
return it->second;
}
std::shared_ptr<BlobMap> p_blobmap_;
// Map key is pointer of executor and value is a data(iterator in map) needed
// to erase
......@@ -291,8 +355,35 @@ struct OneDNNContext::Impl {
std::shared_ptr<std::mutex> p_mutex_;
// 0 - clearing is allowed. x > 0 do not clear.
unsigned int block_next_cache_clearing_ = 0;
// Holds some attributes only used by the onednn kernel calculation.
// Since the original mkldnn op kernels directly fold the operations that
// require fusion into the native kernel and control them with `fuse_xxx`
// attributes, some attributes that look device-independent are also saved
// here for onednn. Ideally, fusion should be implemented as separate fused
// ops and kernels instead of being patched onto a basic operation.
// Because DeviceContext is a global singleton, thread safety must be
// ensured, hence the thread_local variable.
static thread_local AttributeMap dnn_attrs_;
// For onednn, in addition to extra attrs there are also extra inputs,
// but their number is small. Hopefully the implementation can be optimized
// to remove this member in the future.
static thread_local paddle::flat_hash_map<std::string, const DenseTensor*>
dnn_inputs_;
// OneDNN needs the input and output names of the current kernel to generate
// a unique_key.
static thread_local TensorNameMap inputs_name_;
static thread_local TensorNameMap outputs_name_;
static thread_local TensorNameMap inputs_name_;
static thread_local TensorNameMap outputs_name_;
};
thread_local AttributeMap OneDNNContext::Impl::dnn_attrs_ = {};
thread_local paddle::flat_hash_map<std::string, const DenseTensor*>
OneDNNContext::Impl::dnn_inputs_ = {};
thread_local TensorNameMap OneDNNContext::Impl::inputs_name_ = {};
thread_local TensorNameMap OneDNNContext::Impl::outputs_name_ = {};
OneDNNContext::OneDNNContext(const Place& place)
: CPUContext(place), impl_(std::make_unique<Impl>()) {}
......@@ -322,5 +413,49 @@ OneDNNContext::BlobPtr_t<void> OneDNNContext::GetBlob(
return impl_->GetBlob(name);
}
bool OneDNNContext::HasDnnAttr(const std::string& attr_name) const {
return impl_->HasDnnAttr(attr_name);
}
const Attribute& OneDNNContext::GetDnnAttr(const std::string& attr_name) const {
return impl_->GetDnnAttr(attr_name);
}
void OneDNNContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
return impl_->SetDnnAttr(attr_name, std::move(attr));
}
bool OneDNNContext::HasDnnInput(const std::string& input_name) const {
return impl_->HasDnnInput(input_name);
}
const DenseTensor* OneDNNContext::GetDnnInput(
const std::string& input_name) const {
return impl_->GetDnnInput(input_name);
}
void OneDNNContext::SetDnnInput(const std::string& input_name,
const DenseTensor* input) {
return impl_->SetDnnInput(input_name, input);
}
void OneDNNContext::SetInputsName(const TensorNameMap& inputs_name) {
impl_->SetInputsName(inputs_name);
}
void OneDNNContext::SetOutputsName(const TensorNameMap& outputs_name) {
impl_->SetOutputsName(outputs_name);
}
const std::vector<std::string>& OneDNNContext::GetInputsName(
const std::string& input) const {
return impl_->GetInputsName(input);
}
const std::vector<std::string>& OneDNNContext::GetOutputsName(
const std::string& output) const {
return impl_->GetOutputsName(output);
}
} // namespace phi
#endif
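A sketch of how a oneDNN kernel can recover the fluid-level variable names recorded via SetInputsName, e.g. when composing a unique cache key; the slot name "Input" matches the conv2d input slot and `dev_ctx` is assumed to be the OneDNNContext received by the kernel:

    const std::vector<std::string>& input_names = dev_ctx.GetInputsName("Input");
    const std::string& unique_name = input_names[0];  // feeds the cache key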
......@@ -20,9 +20,12 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/attribute.h"
namespace phi {
using TensorNameMap = std::map<std::string, std::vector<std::string>>;
class OneDNNContextThreadLocals {
// default mkldnn session id
......@@ -134,6 +137,26 @@ class OneDNNContext : public CPUContext {
return OneDNNContextThreadLocals::fetch();
}
// Several methods for adapting ONEDNN-specific attributes and inputs
bool HasDnnAttr(const std::string& attr_name) const;
const Attribute& GetDnnAttr(const std::string& attr_name) const;
void SetDnnAttr(const std::string& attr_name, Attribute attr);
bool HasDnnInput(const std::string& input_name) const;
const DenseTensor* GetDnnInput(const std::string& input_name) const;
void SetDnnInput(const std::string& input_name, const DenseTensor* input);
void SetInputsName(const TensorNameMap& inputs_name);
void SetOutputsName(const TensorNameMap& outputs_name);
const std::vector<std::string>& GetInputsName(const std::string& input) const;
const std::vector<std::string>& GetOutputsName(
const std::string& output) const;
static const char* name() { return "OneDNNContext"; }
private:
struct Impl;
std::unique_ptr<Impl> impl_;
......
......@@ -195,6 +195,41 @@ inline std::string CreateKey(const OneDNNContext& dev_ctx, ArgTypes&&... args) {
return key;
}
inline std::vector<std::vector<int64_t>> ToOnednnPadding(
const std::vector<int64_t>& paddings) {
if (paddings.size() == 6) {
int padding_front = paddings[0];
int padding_back = paddings[1];
int padding_top = paddings[2];
int padding_bottom = paddings[3];
int padding_left = paddings[4];
int padding_right = paddings[5];
return {{padding_front, padding_top, padding_left},
{padding_back, padding_bottom, padding_right}};
} else {
int padding_top = paddings[0];
int padding_bottom = paddings[1];
int padding_left = paddings[2];
int padding_right = paddings[3];
return {{padding_top, padding_left}, {padding_bottom, padding_right}};
}
}
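A worked example of the mapping above for the 2-D case, where the paddings vector is ordered {top, bottom, left, right}:

    std::vector<int64_t> paddings = {1, 1, 2, 2};  // top, bottom, left, right
    auto onednn_paddings = ToOnednnPadding(paddings);
    // onednn_paddings == {{1, 2}, {1, 2}}, i.e. {top, left} and {bottom, right}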
// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
const int groups) {
if (groups > 1) {
// if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
// else [o, i, h, w] -> [g, o/g, i, h, w]
weights_tz.push_back(0);
std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
weights_tz[0] = groups;
weights_tz[1] = weights_tz[1] / groups;
}
}
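A worked example of the adjustment above for a 2-D convolution with groups == 2:

    std::vector<int64_t> weights_tz = {8, 4, 3, 3};  // [o, i, h, w]
    GetGroupConvWeightsTz(weights_tz, /*groups=*/2);
    // weights_tz is now {2, 4, 4, 3, 3}, i.e. [g, o/g, i, h, w]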
inline void MatchShapeToLayout(DenseTensor* tensor_in,
DataLayout from,
DataLayout to) {
......
......@@ -39,6 +39,67 @@ using memory = dnnl::memory;
using OneDNNMemoryFormat = dnnl::memory::format_tag;
static void AppendActivation(const OneDNNContext& dev_ctx,
dnnl::post_ops& post_ops, // NOLINT
float activation_scale = 1.0f) {
const auto invalid_attribute =
dev_ctx.HasDnnAttr("fuse_activation")
? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
.empty()
: true;
if (invalid_attribute) return;
const auto fuse_activation =
dev_ctx.HasDnnAttr("fuse_activation")
? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
: "";
const auto fuse_alpha =
dev_ctx.HasDnnAttr("fuse_alpha")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_alpha"))
: 0.0f;
const auto fuse_beta =
dev_ctx.HasDnnAttr("fuse_beta")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fuse_beta"))
: 0.0f;
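// (Note added for clarity) Paddle's hard_sigmoid(x) = clip(alpha * x + beta,
// 0, 1), so it is emulated below as a linear eltwise post-op followed by a
// clip post-op; every other fused activation maps 1:1 onto a single oneDNN
// eltwise algorithm via the table in the else branch.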
if (fuse_activation == "hard_sigmoid") {
post_ops.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_ops.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else {
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"mish", dnnl::algorithm::eltwise_mish},
{"relu", dnnl::algorithm::eltwise_relu},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}};
const auto& activation_type = activation_map.find(fuse_activation);
PADDLE_ENFORCE_NE(
activation_type,
activation_map.end(),
phi::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper",
fuse_activation));
post_ops.append_eltwise(
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
}
}
template <typename T,
typename TForward,
typename TBackward = onednn_dummy_primitive,
......@@ -1085,5 +1146,6 @@ class ClipOneDNNHandler
to_void_cast<T>(input_data));
}
};
} // namespace funcs
} // namespace phi
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
#include <memory>
#include "paddle/phi/backends/xpu/forwards.h"
......@@ -26,7 +28,8 @@ namespace xpu = baidu::xpu::api;
namespace phi {
class XPUContext : public DeviceContext {
class XPUContext : public DeviceContext,
public TypeInfoTraits<DeviceContext, XPUContext> {
public:
XPUContext();
......@@ -65,6 +68,8 @@ class XPUContext : public DeviceContext {
XPUStream stream() const;
static const char* name() { return "XPUContext"; }
private:
struct Impl;
std::unique_ptr<Impl> impl_;
......@@ -79,3 +84,5 @@ using KPSContext = XPUContext;
#endif
} // namespace phi
#endif
......@@ -48,6 +48,6 @@ using Attribute = paddle::variant<bool,
DataLayout,
Place>;
using RuntimeAttrs = paddle::flat_hash_map<std::string, Attribute>;
using AttributeMap = paddle::flat_hash_map<std::string, Attribute>;
} // namespace phi
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/utils/type_registry.h"
namespace phi {
class TensorBase;
......@@ -188,9 +189,21 @@ class PADDLE_API DeviceContext {
*/
Generator* GetHostGenerator() const;
/**
* @brief Return the type information of the derived class to support
* safely downcast in non-rtti environment.
*
* @return The type information of the derived class.
*/
TypeInfo<DeviceContext> type_info() const { return type_info_; }
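/* (Note added for clarity) The intent is that each derived context registers
 * itself through TypeInfoTraits<DeviceContext, Derived> (see XPUContext and
 * OneDNNContext in this patch), so code holding a DeviceContext& can compare
 * type_info() against the derived type's record instead of relying on
 * dynamic_cast; the exact comparison helper is not shown in this hunk. */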
private:
struct Impl;
std::unique_ptr<Impl> impl_;
template <typename T, typename U>
friend class TypeInfoTraits;
TypeInfo<DeviceContext> type_info_{TypeInfo<DeviceContext>::kUnknownType};
};
} // namespace phi
......@@ -43,6 +43,7 @@ limitations under the License. */
#include "paddle/phi/core/errors.h"
#include "paddle/utils/string/printf.h"
#include "paddle/utils/string/to_string.h"
#include "paddle/utils/variant.h"
DECLARE_int32(call_stack_level);
......@@ -409,80 +410,75 @@ struct EnforceNotMet : public std::exception {
/** EXTENDED TOOL FUNCTIONS WITH CHECKING **/
/*
* Summary: This macro is used to get Variable or internal type
* data (such as LoDTensor or SelectedRows) of the Input and
* Output in op, generally used when call scope.FindVar(Input/
* Output("Name")) or ctx.Input<LoDTensor>().
* Firstly this macro check whether the obtained pointer is null,
* and then return data if it is not null.
*
* Note: This macro is only suitable for specific scenarios and
* does not intended to be widely used. If it cannot meet the
* requirements, please use other PADDLE_ENFORCE** check macro.
* Summary: The PADDLE_GET(_**) series of macros is used to call paddle::get
* safely. paddle::get is not a completely safe API: although it works in
* most cases, in extreme cases it may fail and directly throw a
* paddle::bad_variant_access exception without any stack
* information.
* This kind of problem is difficult to debug, so these macros are added to
* enrich the paddle::get error information. At the same time, direct use of
* paddle::get is restricted by a CI rule.
*
* Parameters:
*     __PTR: pointer
* __ROLE: (string), Input or Output
* __NAME: (string), Input or Output name
* __OP_TYPE: (string), the op type
*
* Return: The data pointed to by the pointer.
*     __TYPE: the target variable type
* __VALUE: the target variable to get
*
* Examples:
* GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X", "Mul");
* - unsafe writing: int x = paddle::get<int>(y);
* - safe writing: int x = PADDLE_GET(int, y);
*
* Note: GCC 4.8 cannot select the right overloaded function here, so we need
* to define different functions and macros here; after we upgrade the
* CI gcc version, we can define a single PADDLE_GET macro.
*/
#define GET_DATA_SAFELY(__PTR, __ROLE, __NAME, __OP_TYPE) \
(([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type { \
auto* __ptr = (__PTR); \
if (UNLIKELY(nullptr == __ptr)) { \
auto __summary__ = phi::errors::NotFound( \
"Unable to get %s data of %s %s in operator %s. " \
"Possible reasons are:\n" \
" 1. The %s is not the %s of operator %s;\n" \
" 2. The %s has no corresponding variable passed in;\n" \
" 3. The %s corresponding variable is not initialized.", \
phi::demangle( \
typeid(std::add_lvalue_reference<decltype(*__ptr)>::type) \
.name()), \
__ROLE, \
__NAME, \
__OP_TYPE, \
__NAME, \
__ROLE, \
__OP_TYPE, \
__NAME, \
__NAME); \
auto __message__ = ::paddle::string::Sprintf( \
"%s\n [Hint: pointer " #__PTR " should not be null.]", \
__summary__.error_message()); \
__THROW_ERROR_INTERNAL__( \
phi::ErrorSummary(__summary__.code(), __message__)); \
namespace details {
#define DEFINE_SAFE_PADDLE_GET( \
__InputType, __OutputType, __OutputTypePtr, __FuncName) \
template <typename OutputType, typename InputType> \
auto __FuncName( \
__InputType input, const char* expression, const char* file, int line) \
->typename std::conditional<std::is_pointer<InputType>::value, \
__OutputTypePtr, \
__OutputType>::type { \
try { \
return paddle::get<OutputType>(input); \
} catch (paddle::bad_variant_access const&) { \
HANDLE_THE_ERROR \
throw ::phi::enforce::EnforceNotMet( \
phi::errors::InvalidArgument( \
"paddle::get failed, cannot get value " \
"(%s) by type %s, its type is %s.", \
expression, \
phi::enforce::demangle(typeid(OutputType).name()), \
phi::enforce::demangle(input.type().name())), \
file, \
line); \
END_HANDLE_THE_ERROR \
} \
return *__ptr; \
})())
}
/*
* Summary: This macro is used to check whether op has specified
* Input or Output Variables. Because op's Input and Output
* checking are written similarly, so abstract this macro.
*
* Parameters:
*     __EXPR: (bool), the bool expression
* __ROLE: (string), Input or Output
* __NAME: (string), Input or Output name
* __OP_TYPE: (string), the op type
*
* Examples:
* OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
*/
#define OP_INOUT_CHECK(__EXPR, __ROLE, __NAME, __OP_TYPE) \
do { \
PADDLE_ENFORCE_EQ( \
__EXPR, \
true, \
phi::errors::NotFound( \
"No %s(%s) found for %s operator.", __ROLE, __NAME, __OP_TYPE)); \
} while (0)
DEFINE_SAFE_PADDLE_GET(InputType&, OutputType&, OutputType*, SafeBoostGet);
DEFINE_SAFE_PADDLE_GET(const InputType&,
const OutputType&,
const OutputType*,
SafeBoostGetConst);
DEFINE_SAFE_PADDLE_GET(InputType&&,
OutputType,
OutputType*,
SafeBoostGetMutable);
} // namespace details
#define PADDLE_GET(__TYPE, __VALUE) \
phi::enforce::details::SafeBoostGet<__TYPE>( \
__VALUE, #__VALUE, __FILE__, __LINE__)
#define PADDLE_GET_CONST(__TYPE, __VALUE) \
phi::enforce::details::SafeBoostGetConst<__TYPE>( \
__VALUE, #__VALUE, __FILE__, __LINE__)
#define PADDLE_GET_MUTABLE(__TYPE, __VALUE) \
phi::enforce::details::SafeBoostGetMutable<__TYPE>( \
__VALUE, #__VALUE, __FILE__, __LINE__)
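/*
 * Example (added for illustration): phi::Attribute is a paddle::variant, so
 *
 *   phi::Attribute attr = std::string("relu");
 *   auto act = PADDLE_GET_CONST(std::string, attr);  // act == "relu"
 *
 * Requesting any other alternative of the variant throws EnforceNotMet,
 * reporting the expression text, the requested type and the stored type.
 */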
} // namespace enforce
using namespace enforce; // NOLINT
......
......@@ -138,8 +138,6 @@ class KernelContext {
template <typename AttrType>
const AttrType& AttrAt(size_t idx) const;
const RuntimeAttrs& GetRuntimeAttrs() const { return runtime_attrs_; }
size_t InputsSize() const { return inputs_.size(); }
size_t OutputsSize() const { return outputs_.size(); }
size_t AttrsSize() const { return attrs_.size(); }
......@@ -161,8 +159,6 @@ class KernelContext {
paddle::small_vector<std::pair<int, int>, kInputSmallVectorSize> input_range_;
paddle::small_vector<std::pair<int, int>, kOutputSmallVectorSize>
output_range_;
RuntimeAttrs runtime_attrs_;
};
} // namespace phi
......@@ -233,8 +233,6 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
args_def->AppendAttribute(AttributeType::DATA_LAYOUT);
} else if (arg_type == std::type_index(typeid(Place))) {
args_def->AppendAttribute(AttributeType::PLACE);
} else if (arg_type == std::type_index(typeid(RuntimeAttrs))) {
// do nothing
} else {
PADDLE_THROW(phi::errors::Unavailable(
"Unsupported kernel argument type `%s`.", arg_type.name()));
......
......@@ -14,13 +14,7 @@
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
......@@ -330,21 +324,6 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray);
template <typename... Tail>
struct KernelCallHelper<const RuntimeAttrs&, Tail...> {
template <int dev_ctx_idx,
int in_idx,
int attr_idx,
int out_idx,
typename... PreviousArgs>
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {
const auto& runtime_attrs = ctx->GetRuntimeAttrs();
KernelCallHelper<Tail...>::
template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx>(
ctx, pargs..., runtime_attrs);
}
};
/* End case */
template <typename T>
struct KernelCallHelper<TypeTag<T>> {
......
......@@ -409,12 +409,9 @@ void ConvInferMeta(const MetaTensor& input,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config) {
std::vector<int> paddings = paddings_t;
......@@ -559,27 +556,27 @@ void ConvInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}
void ConvInferInferMeta(const MetaTensor& input,
void Conv3DInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config) {
ConvInferMeta(input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
/*use_addto=*/false,
/*workspace_size_MB=*/512, // useless in infermeta
/*exhaustive_search=*/false,
out,
config);
}
......@@ -922,6 +919,31 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
loss->share_lod(logits);
}
void DepthwiseConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config) {
ConvInferMeta(input,
filter,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
out,
config);
}
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
......@@ -2876,4 +2898,3 @@ void Unpool3dInferMeta(const MetaTensor& x,
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
PD_REGISTER_INFER_META_FN(conv2d_infer, phi::ConvInferInferMeta);
......@@ -80,24 +80,24 @@ void ConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ConvInferInferMeta(const MetaTensor& input,
void Conv3DInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config = MetaConfig());
......@@ -143,6 +143,20 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
MetaTensor* loss,
MetaConfig config = MetaConfig());
void DepthwiseConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
MetaTensor* out,
MetaConfig config = MetaConfig());
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ConvGradGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad);
template <typename T, typename Context>
void Conv3DGradGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad);
} // namespace phi
......@@ -25,13 +25,10 @@ void ConvGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad);
......@@ -42,7 +39,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -59,7 +56,7 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -70,4 +67,41 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
DenseTensor* input_grad,
DenseTensor* filter_grad);
template <typename T, typename Context>
void ConvGradGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad);
template <typename T, typename Context>
void Conv3DGradGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ConvInferKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
ConvKernel<T, Context>(dev_ctx,
input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
/*use_addto=*/false,
/*workspace_size_MB=*/
paddle::platform::GetDefaultConvWorkspaceSizeLimitMB(),
/*exhaustive_search=*/false,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(
conv2d_infer, CPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
conv2d_infer, GPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {}
#endif
......@@ -25,12 +25,9 @@ void ConvKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* out);
template <typename T, typename Context>
......@@ -54,7 +51,7 @@ void DepthwiseConvKernel(const Context& dev_ctx,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -64,16 +61,4 @@ void DepthwiseConvKernel(const Context& dev_ctx,
bool fuse_relu,
DenseTensor* out);
template <typename T, typename Context>
void ConvInferKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_grad_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/conv_grad_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void Conv3DGradGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
ConvGradGradKernel<T>(ctx,
input,
filter,
out_grad,
input_grad_grad,
filter_grad_grad,
strides,
paddings_t,
padding_algorithm,
groups,
dilations_t,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search_t,
input_grad,
filter_grad,
out_grad_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(
conv2d_grad_grad, CPU, ALL_LAYOUT, phi::ConvGradGradKernel, float, double) {
}
PD_REGISTER_KERNEL(conv3d_grad_grad,
CPU,
ALL_LAYOUT,
phi::Conv3DGradGradKernel,
float,
double) {}
......@@ -27,7 +27,7 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -43,13 +43,10 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
out_grad,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
input_grad,
filter_grad);
}
......@@ -61,7 +58,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -76,17 +73,50 @@ void Conv3DGradKernel(const Context& dev_ctx,
out_grad,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
input_grad,
filter_grad);
}
template <typename T, typename Context>
void Conv3DGradGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
ConvGradGradKernel<T>(ctx,
input,
filter,
out_grad,
input_grad_grad,
filter_grad_grad,
strides,
paddings_t,
padding_algorithm,
dilations_t,
groups,
data_format,
input_grad,
filter_grad,
out_grad_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(
......@@ -101,3 +131,14 @@ PD_REGISTER_KERNEL(depthwise_conv2d_grad,
PD_REGISTER_KERNEL(
conv3d_grad, CPU, ALL_LAYOUT, phi::Conv3DGradKernel, float, double) {}
PD_REGISTER_KERNEL(
conv2d_grad_grad, CPU, ALL_LAYOUT, phi::ConvGradGradKernel, float, double) {
}
PD_REGISTER_KERNEL(conv3d_grad_grad,
CPU,
ALL_LAYOUT,
phi::Conv3DGradGradKernel,
float,
double) {}
......@@ -19,6 +19,30 @@
#include "paddle/phi/kernels/impl/conv_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void ConvKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
DenseTensor* out) {
ConvKernelImpl<T>(dev_ctx,
input,
filter,
strides,
paddings,
padding_algorithm,
groups,
dilations,
data_format,
out);
}
template <typename T, typename Context>
void DepthwiseConvKernel(const Context& dev_ctx,
const DenseTensor& input,
......@@ -34,7 +58,7 @@ void DepthwiseConvKernel(const Context& dev_ctx,
bool exhaustive_search,
bool fuse_relu,
DenseTensor* out) {
ConvKernel<T>(dev_ctx,
ConvKernelImpl<T>(dev_ctx,
input,
filter,
strides,
......@@ -43,9 +67,6 @@ void DepthwiseConvKernel(const Context& dev_ctx,
groups,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
out);
}
......@@ -63,7 +84,7 @@ void Conv3DKernel(const Context& dev_ctx,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* out) {
ConvKernel<T>(dev_ctx,
ConvKernelImpl<T>(dev_ctx,
input,
filter,
strides,
......@@ -72,9 +93,6 @@ void Conv3DKernel(const Context& dev_ctx,
groups,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
out);
}
......
......@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/kernels/erfinv_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
......
......@@ -12,10 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/kernels/erfinv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/erfinv_kernel_impl.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
auto eigen_in = EigenVector<T>::Flatten(x);
auto eigen_out = EigenVector<T>::Flatten(*out);
auto& place = *ctx.eigen_device();
constexpr T half = static_cast<T>(0.5);
constexpr T half_sqrt = static_cast<T>(M_SQRT1_2);
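// (Note added for clarity) Eigen's ndtri() is the inverse standard normal
// CDF. Since erf(z) = 2 * Phi(z * sqrt(2)) - 1, inverting gives
// erfinv(x) = ndtri((x + 1) / 2) / sqrt(2), which is the expression below.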
eigen_out.device(place) = (eigen_in * half + half).ndtri() * half_sqrt;
}
} // namespace phi
PD_REGISTER_KERNEL(erfinv, CPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_grad_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/conv_grad_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
conv2d_grad_grad, GPU, ALL_LAYOUT, phi::ConvGradGradKernel, float, double) {
}
......@@ -27,7 +27,7 @@ void Conv3DGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -42,13 +42,10 @@ void Conv3DGradKernel(const Context& dev_ctx,
out_grad,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
input_grad,
filter_grad);
}
......@@ -60,3 +57,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
conv3d_grad, GPU, ALL_LAYOUT, phi::Conv3DGradKernel, float, double) {}
PD_REGISTER_KERNEL(
conv2d_grad_grad, GPU, ALL_LAYOUT, phi::ConvGradGradKernel, float, double) {
}
......@@ -20,6 +20,29 @@
namespace phi {
template <typename T, typename Context>
void ConvKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
DenseTensor* out) {
ConvKernelImpl<T>(dev_ctx,
input,
filter,
strides,
paddings,
padding_algorithm,
groups,
dilations,
data_format,
out);
}
template <typename T, typename Context>
void Conv3DKernel(const Context& dev_ctx,
const DenseTensor& input,
......@@ -34,7 +57,7 @@ void Conv3DKernel(const Context& dev_ctx,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* out) {
ConvKernel<T>(dev_ctx,
ConvKernelImpl<T>(dev_ctx,
input,
filter,
strides,
......@@ -43,9 +66,6 @@ void Conv3DKernel(const Context& dev_ctx,
groups,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
out);
}
......
......@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/kernels/erfinv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_grad_grad_kernel.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/phi/kernels/gpudnn/conv_miopen_helper.h"
#else
#include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
#endif
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/padding.h"
#include "paddle/phi/kernels/impl/conv_cudnn_impl.h"
namespace phi {
template <typename T, typename Context>
void ConvCudnnGradGradKernel(
const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
auto X = &input;
auto W = &filter;
auto dO = &out_grad;
auto ddX = input_grad_grad.get_ptr();
auto ddW = filter_grad_grad.get_ptr();
auto ddO = out_grad_grad;
auto dW = filter_grad;
auto dX = input_grad;
if (ddO) {
ctx.template Alloc<T>(ddO);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, ddO, static_cast<T>(0));
}
if (dW) {
ctx.template Alloc<T>(dW);
}
if (dX) {
ctx.template Alloc<T>(dX);
}
// const T* x = X->data<T>();
const T* dy = dO->data<T>();
const T* w = W->data<T>();
const T* ddx = nullptr;
const T* ddw = nullptr;
T *dw, *dx, *ddy;
dw = dx = ddy = nullptr;
T* transformed_dx = nullptr;
std::vector<int> dilations = dilations_t;
bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic,
false,
phi::errors::InvalidArgument(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time."));
std::vector<int> paddings = paddings_t;
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensors to channel first-----------
DenseTensor transformed_X_channel(X->type());
DenseTensor transformed_dO_channel(dO->type());
DenseTensor transformed_ddX_channel(X->type());
DenseTensor transformed_ddO_channel(dO->type());
DenseTensor transformed_dX_channel(X->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(ctx, X, &transformed_X_channel);
TransToChannelFirst<Context, T>(ctx, X, &transformed_X_channel);
ResizeToChannelFirst<Context, T>(ctx, dO, &transformed_dO_channel);
TransToChannelFirst<Context, T>(ctx, dO, &transformed_dO_channel);
if (ddX) {
ResizeToChannelFirst<Context, T>(ctx, ddX, &transformed_ddX_channel);
TransToChannelFirst<Context, T>(ctx, ddX, &transformed_ddX_channel);
}
if (ddO) {
ResizeToChannelFirst<Context, T>(ctx, ddO, &transformed_ddO_channel);
}
if (dX) {
ResizeToChannelFirst<Context, T>(ctx, dX, &transformed_dX_channel);
ctx.template Alloc<T>(&transformed_dX_channel);
}
} else {
transformed_X_channel = *X;
transformed_dO_channel = *dO;
if (ddX) {
transformed_ddX_channel = *ddX;
}
if (ddO) {
transformed_ddO_channel.ShareDataWith(*ddO);
}
if (dX) {
transformed_dX_channel.ShareDataWith(*dX);
}
}
auto in_dims = transformed_X_channel.dims();
auto filter_dims = W->dims();
DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim);
DenseTensor transformed_X(X->type());
DenseTensor transformed_ddX(X->type());
DenseTensor transformed_dX(X->type());
std::vector<int> padding_common(data_dim, 0);
std::vector<int> input_pad(X->dims().size() * 2, 0);
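// (Note added for clarity) cuDNN takes a single padding value per spatial
// dimension, so when Paddle's [before, after] paddings are asymmetric the
// input is explicitly padded by the difference below, the common (minimum)
// padding is handed to cuDNN, and the extra border is sliced off dX again
// via RemovePaddingSlice at the end of the kernel.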
if (!is_sys_pad) {
// get pad
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_X_channel.dims()[0];
new_input_shape_vec[1] = transformed_X_channel.dims()[1];
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_X_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
DDim new_input_shape(make_ddim(new_input_shape_vec));
transformed_X.Resize(new_input_shape);
transformed_ddX.Resize(new_input_shape);
transformed_dX.Resize(new_input_shape);
ctx.template Alloc<T>(&transformed_X);
if (ddX) {
ctx.template Alloc<T>(&transformed_ddX);
}
if (dX) {
ctx.template Alloc<T>(&transformed_dX);
}
// pad for input
const int rank = X->dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
funcs::PadFunction<Context, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
if (ddX) {
funcs::PadFunction<Context, T, 4>(ctx,
input_pad,
transformed_ddX_channel,
pad_value,
&transformed_ddX);
}
} break;
case 5: {
funcs::PadFunction<Context, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
if (ddX) {
funcs::PadFunction<Context, T, 5>(ctx,
input_pad,
transformed_ddX_channel,
pad_value,
&transformed_ddX);
}
} break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"ConvOp only support tensors with 4 or 5 dimensions."));
}
} else {
transformed_X.ShareDataWith(transformed_X_channel);
if (ddX) {
transformed_ddX.ShareDataWith(transformed_ddX_channel);
}
if (dX) {
transformed_dX.ShareDataWith(transformed_dX_channel);
}
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* x = transformed_X.data<T>();
int iwo_group = groups;
int c_group = 1;
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_group = 1;
c_group = groups;
groups = 1;
#endif
auto dtype = paddle::platform::CudnnDataType<T>::type;
auto handle = ctx.cudnn_handle();
auto layout = paddle::platform::GetCudnnTensorFormat(
paddle::platform::DataLayout::kNCHW);
ConvArgs args1{handle,
&transformed_ddX,
W,
&transformed_ddO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
ConvArgs args2{handle,
&transformed_X,
ddW,
&transformed_ddO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
ConvArgs args3{handle,
&transformed_ddX,
dW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
ConvArgs args4{handle,
&transformed_dX,
ddW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
#ifdef PADDLE_WITH_HIP
SearchResult<miopenConvFwdAlgorithm_t> fwd_result1;
SearchResult<miopenConvFwdAlgorithm_t> fwd_result2;
SearchResult<miopenConvBwdDataAlgorithm_t> data_result;
SearchResult<miopenConvBwdWeightsAlgorithm_t> filter_result;
#else
SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result1;
SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result2;
SearchResult<cudnnConvolutionBwdDataAlgo_t> data_result;
SearchResult<cudnnConvolutionBwdFilterAlgo_t> filter_result;
#endif
// ddo = conv(ddI, W) + conv(I, ddW)
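// (Note added for clarity) For Y = conv(X, W) the double-grad terms follow
// from linearity: ddY = conv(ddX, W) + conv(X, ddW); the extra filter grad
// comes from conv_bwd_filter(ddX, dY) and the extra input grad from
// conv_bwd_data(dY, ddW), which is what args1..args4 above are set up for.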
size_t workspace_size = 0;
T* transformed_ddy_channel = nullptr;
if (ddO) {
ddy = ddO->data<T>();
transformed_ddy_channel = transformed_ddO_channel.data<T>();
if (ddX) {
args1.idesc.set(transformed_ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(transformed_ddO_channel, iwo_group);
args1.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search1 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
fwd_result1.algo = search1::Find<T>(
args1, exhaustive_search, false, workspace_size, ctx);
#else
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result1 = search1::Find<T>(ctx, args1, exhaustive_search, false);
workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
#endif
}
if (ddW) {
ddw = ddW->data<T>();
args2.idesc.set(transformed_X, iwo_group);
args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(transformed_ddO_channel, iwo_group);
args2.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search2 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
fwd_result2.algo = search2::Find<T>(
args2, exhaustive_search, false, workspace_size, ctx);
#else
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result2 = search2::Find<T>(ctx, args2, exhaustive_search, false);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
#endif
}
}
if (dW && ddX) {
dw = dW->data<T>();
args3.idesc.set(transformed_ddX, iwo_group);
args3.wdesc.set(*dW, layout, iwo_group);
args3.odesc.set(transformed_dO_channel, iwo_group);
args3.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search3 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_result.algo = search3::Find<T>(
args3, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result =
search3::Find<T>(ctx, args3, exhaustive_search, deterministic);
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
}
if (ddW && dX) {
transformed_dx = transformed_dX.data<T>();
args4.idesc.set(transformed_dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(transformed_dO_channel, iwo_group);
args4.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search4 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_result.algo = search4::Find<T>(
args4, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_result =
search4::Find<T>(ctx, args4, exhaustive_search, deterministic);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
#endif
}
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(
transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(transformed_dO_channel.dims(),
DataLayout::kNCHW,
&o_n,
&o_c,
&o_d,
&o_h,
&o_w);
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = W->numel() / groups;
ScalingParamType<T> alpha = 1.0f;
ScalingParamType<T> beta = 0.0f;
// NOTE(zhiqiu): inplace addto is not supported in double grad yet.
// ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f :
// 0.0f;
// VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr<bool>("use_addto");
auto workspace_handle = ctx.cudnn_workspace_handle();
if (ddO) {
if (ddX) {
ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionForward(
handle,
&alpha,
args1.idesc.desc(),
ddx,
args1.wdesc.desc(),
w,
args1.cdesc.desc(),
fwd_result1.algo,
&beta,
args1.odesc.desc(),
transformed_ddy_channel,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionForward(
handle,
&alpha,
args1.idesc.desc(),
ddx + i * group_offset_in,
args1.wdesc.desc(),
w + i * group_offset_filter,
args1.cdesc.desc(),
fwd_result1.algo,
workspace_ptr,
workspace_size,
&beta,
args1.odesc.desc(),
transformed_ddy_channel + i * group_offset_out));
},
workspace_size);
}
#endif
}
if (ddW) {
#ifdef PADDLE_WITH_HIP
// MIOPEN ONLY support beta to be 0.0f
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionForward(
handle,
&alpha,
args2.idesc.desc(),
x,
args2.wdesc.desc(),
ddw,
args2.cdesc.desc(),
fwd_result2.algo,
&beta,
args2.odesc.desc(),
transformed_ddy_channel,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionForward(
handle,
&alpha,
args2.idesc.desc(),
x + i * group_offset_in,
args2.wdesc.desc(),
ddw + i * group_offset_filter,
args2.cdesc.desc(),
fwd_result2.algo,
workspace_ptr,
workspace_size,
&alpha,
args2.odesc.desc(),
transformed_ddy_channel + i * group_offset_out));
},
workspace_size);
}
#endif
}
if (channel_last) {
TransToChannelLast<Context, T>(ctx, &transformed_ddO_channel, ddO);
}
}
T* transformed_dy_channel = transformed_dO_channel.data<T>();
if (dW && ddX) {
ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionBackwardWeights(
handle,
&alpha,
args3.odesc.desc(),
transformed_dy_channel,
args3.idesc.desc(),
ddx,
args3.cdesc.desc(),
filter_result.algo,
&beta,
args3.wdesc.desc(),
dw,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionBackwardFilter(
handle,
&alpha,
args3.idesc.desc(),
ddx + i * group_offset_in,
args3.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args3.cdesc.desc(),
filter_result.algo,
workspace_ptr,
workspace_size,
&beta,
args3.wdesc.desc(),
dw + i * group_offset_filter));
},
workspace_size);
}
#endif
}
if (dX && ddW) {
ddw = ddW->data<T>();
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionBackwardData(
handle,
&alpha,
args4.odesc.desc(),
transformed_dy_channel,
args4.wdesc.desc(),
ddw,
args4.cdesc.desc(),
data_result.algo,
&beta,
args4.idesc.desc(),
transformed_dx,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionBackwardData(
handle,
&alpha,
args4.wdesc.desc(),
ddw + i * group_offset_filter,
args4.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args4.cdesc.desc(),
data_result.algo,
workspace_ptr,
workspace_size,
&beta,
args4.idesc.desc(),
transformed_dx + i * group_offset_in));
},
workspace_size);
}
#endif
if (!is_sys_pad) {
// reverse padded input
std::vector<int> starts(X->dims().size(), 0);
std::vector<int> axes(X->dims().size(), 0);
for (size_t i = 0; i < X->dims().size(); ++i) {
starts[i] = input_pad[2 * i];
axes[i] = i;
}
if (X->dims().size() == 4) {
RemovePaddingSlice<Context, T, 4>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
} else {
RemovePaddingSlice<Context, T, 5>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
}
}
if (channel_last) {
TransToChannelLast<Context, T>(ctx, &transformed_dX_channel, dX);
}
}
}
template <typename T, typename Context>
void DepthwiseConvDoubleGradGPUDNNKernel(
const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
bool fuse_relu,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
ConvCudnnGradGradKernel<T>(ctx,
input,
filter,
out_grad,
input_grad_grad,
filter_grad_grad,
strides,
paddings_t,
padding_algorithm,
groups,
dilations_t,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search_t,
input_grad,
filter_grad,
out_grad_grad);
}
template <typename T, typename Context>
void Conv3DCudnnGradGradKernel(
const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
ConvCudnnGradGradKernel<T>(ctx,
input,
filter,
out_grad,
input_grad_grad,
filter_grad_grad,
strides,
paddings_t,
padding_algorithm,
groups,
dilations_t,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search_t,
input_grad,
filter_grad,
out_grad_grad);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(conv2d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(conv3d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3DCudnnGradGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
GPU,
ALL_LAYOUT,
phi::DepthwiseConvDoubleGradGPUDNNKernel,
float,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(conv2d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(conv3d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3DCudnnGradGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
GPU,
ALL_LAYOUT,
phi::DepthwiseConvDoubleGradGPUDNNKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(conv2d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(conv3d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3DCudnnGradGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
GPU,
ALL_LAYOUT,
phi::DepthwiseConvDoubleGradGPUDNNKernel,
float,
double,
phi::dtype::float16) {}
#endif
#endif
......@@ -44,12 +44,9 @@ void ConvCudnnGradKernel(const Context& ctx,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* input_grad,
DenseTensor* filter_grad) {
if (input_grad) {
......@@ -59,11 +56,25 @@ void ConvCudnnGradKernel(const Context& ctx,
ctx.template Alloc<T>(filter_grad);
}
bool has_use_addto = ctx.HasDnnAttr("use_addto");
VLOG(4) << "GPUContext contains `use_addto`: " << has_use_addto;
bool use_addto = has_use_addto
? PADDLE_GET_CONST(bool, ctx.GetDnnAttr("use_addto"))
: false;
std::vector<int> dilations = dilations_t;
std::vector<int> strides = strides_t;
std::vector<int> paddings = paddings_t;
bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
bool has_exhaustive_search = ctx.HasDnnAttr("exhaustive_search");
VLOG(4) << "GPUContext contains `exhaustive_search`: "
<< has_exhaustive_search;
bool exhaustive_search_attr =
has_exhaustive_search
? PADDLE_GET_CONST(bool, ctx.GetDnnAttr("exhaustive_search"))
: false;
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || exhaustive_search_attr;
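// (Note added for clarity) `use_addto` and `exhaustive_search` used to be
// plain kernel attributes; with this change they are optional extra
// attributes fetched from the device context and default to false when the
// framework did not provide them.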
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic,
......@@ -588,7 +599,7 @@ void Conv3DCudnnGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -603,13 +614,10 @@ void Conv3DCudnnGradKernel(const Context& dev_ctx,
out_grad,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
input_grad,
filter_grad);
}
......@@ -621,7 +629,7 @@ void DepthwiseConvCudnnGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -637,17 +645,717 @@ void DepthwiseConvCudnnGradKernel(const Context& dev_ctx,
out_grad,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
input_grad,
filter_grad);
}
template <typename T, typename Context>
void ConvCudnnGradGradKernel(
const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
auto X = &input;
auto W = &filter;
auto dO = &out_grad;
auto ddX = input_grad_grad.get_ptr();
auto ddW = filter_grad_grad.get_ptr();
auto ddO = out_grad_grad;
auto dW = filter_grad;
auto dX = input_grad;
if (ddO) {
ctx.template Alloc<T>(ddO);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, ddO, static_cast<T>(0));
}
if (dW) {
ctx.template Alloc<T>(dW);
}
if (dX) {
ctx.template Alloc<T>(dX);
}
// const T* x = X->data<T>();
const T* dy = dO->data<T>();
const T* w = W->data<T>();
const T* ddx = nullptr;
const T* ddw = nullptr;
T *dw, *dx, *ddy;
dw = dx = ddy = nullptr;
T* transformed_dx = nullptr;
std::vector<int> dilations = dilations_t;
bool has_exhaustive_search = ctx.HasDnnAttr("exhaustive_search");
VLOG(4) << "GPUContext contains `exhaustive_search`: "
<< has_exhaustive_search;
bool exhaustive_search_attr =
has_exhaustive_search
? PADDLE_GET_CONST(bool, ctx.GetDnnAttr("exhaustive_search"))
: false;
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || exhaustive_search_attr;
bool deterministic = FLAGS_cudnn_deterministic;
auto exhaustive_deterministic = exhaustive_search && deterministic;
PADDLE_ENFORCE_EQ(exhaustive_deterministic,
false,
phi::errors::InvalidArgument(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time."));
std::vector<int> paddings = paddings_t;
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensors to channel first-----------
DenseTensor transformed_X_channel(X->type());
DenseTensor transformed_dO_channel(dO->type());
DenseTensor transformed_ddX_channel(X->type());
DenseTensor transformed_ddO_channel(dO->type());
DenseTensor transformed_dX_channel(X->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(ctx, X, &transformed_X_channel);
TransToChannelFirst<Context, T>(ctx, X, &transformed_X_channel);
ResizeToChannelFirst<Context, T>(ctx, dO, &transformed_dO_channel);
TransToChannelFirst<Context, T>(ctx, dO, &transformed_dO_channel);
if (ddX) {
ResizeToChannelFirst<Context, T>(ctx, ddX, &transformed_ddX_channel);
TransToChannelFirst<Context, T>(ctx, ddX, &transformed_ddX_channel);
}
if (ddO) {
ResizeToChannelFirst<Context, T>(ctx, ddO, &transformed_ddO_channel);
}
if (dX) {
ResizeToChannelFirst<Context, T>(ctx, dX, &transformed_dX_channel);
ctx.template Alloc<T>(&transformed_dX_channel);
}
} else {
transformed_X_channel = *X;
transformed_dO_channel = *dO;
if (ddX) {
transformed_ddX_channel = *ddX;
}
if (ddO) {
transformed_ddO_channel.ShareDataWith(*ddO);
}
if (dX) {
transformed_dX_channel.ShareDataWith(*dX);
}
}
auto in_dims = transformed_X_channel.dims();
auto filter_dims = W->dims();
DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim);
DenseTensor transformed_X(X->type());
DenseTensor transformed_ddX(X->type());
DenseTensor transformed_dX(X->type());
std::vector<int> padding_common(data_dim, 0);
std::vector<int> input_pad(X->dims().size() * 2, 0);
if (!is_sys_pad) {
// get pad
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_X_channel.dims()[0];
new_input_shape_vec[1] = transformed_X_channel.dims()[1];
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_X_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
DDim new_input_shape(make_ddim(new_input_shape_vec));
transformed_X.Resize(new_input_shape);
transformed_ddX.Resize(new_input_shape);
transformed_dX.Resize(new_input_shape);
ctx.template Alloc<T>(&transformed_X);
if (ddX) {
ctx.template Alloc<T>(&transformed_ddX);
}
if (dX) {
ctx.template Alloc<T>(&transformed_dX);
}
// pad for input
const int rank = X->dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
funcs::PadFunction<Context, T, 4>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
if (ddX) {
funcs::PadFunction<Context, T, 4>(ctx,
input_pad,
transformed_ddX_channel,
pad_value,
&transformed_ddX);
}
} break;
case 5: {
funcs::PadFunction<Context, T, 5>(
ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
if (ddX) {
funcs::PadFunction<Context, T, 5>(ctx,
input_pad,
transformed_ddX_channel,
pad_value,
&transformed_ddX);
}
} break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"ConvOp only support tensors with 4 or 5 dimensions."));
}
} else {
transformed_X.ShareDataWith(transformed_X_channel);
if (ddX) {
transformed_ddX.ShareDataWith(transformed_ddX_channel);
}
if (dX) {
transformed_dX.ShareDataWith(transformed_dX_channel);
}
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* x = transformed_X.data<T>();
int iwo_group = groups;
int c_group = 1;
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_group = 1;
c_group = groups;
groups = 1;
#endif
auto dtype = paddle::platform::CudnnDataType<T>::type;
auto handle = ctx.cudnn_handle();
auto layout = paddle::platform::GetCudnnTensorFormat(
paddle::platform::DataLayout::kNCHW);
ConvArgs args1{handle,
&transformed_ddX,
W,
&transformed_ddO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
ConvArgs args2{handle,
&transformed_X,
ddW,
&transformed_ddO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
ConvArgs args3{handle,
&transformed_ddX,
dW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
ConvArgs args4{handle,
&transformed_dX,
ddW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype,
groups,
paddle::platform::DataLayout::kNCHW};
#ifdef PADDLE_WITH_HIP
SearchResult<miopenConvFwdAlgorithm_t> fwd_result1;
SearchResult<miopenConvFwdAlgorithm_t> fwd_result2;
SearchResult<miopenConvBwdDataAlgorithm_t> data_result;
SearchResult<miopenConvBwdWeightsAlgorithm_t> filter_result;
#else
SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result1;
SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result2;
SearchResult<cudnnConvolutionBwdDataAlgo_t> data_result;
SearchResult<cudnnConvolutionBwdFilterAlgo_t> filter_result;
#endif
// ddo = conv(ddI, W) + conv(I, ddW)
size_t workspace_size = 0;
T* transformed_ddy_channel = nullptr;
if (ddO) {
ddy = ddO->data<T>();
transformed_ddy_channel = transformed_ddO_channel.data<T>();
if (ddX) {
args1.idesc.set(transformed_ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(transformed_ddO_channel, iwo_group);
args1.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search1 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
fwd_result1.algo = search1::Find<T>(
args1, exhaustive_search, false, workspace_size, ctx);
#else
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result1 = search1::Find<T>(ctx, args1, exhaustive_search, false);
workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
#endif
}
if (ddW) {
ddw = ddW->data<T>();
args2.idesc.set(transformed_X, iwo_group);
args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(transformed_ddO_channel, iwo_group);
args2.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search2 = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
fwd_result2.algo = search2::Find<T>(
args2, exhaustive_search, false, workspace_size, ctx);
#else
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result2 = search2::Find<T>(ctx, args2, exhaustive_search, false);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
#endif
}
}
if (dW && ddX) {
dw = dW->data<T>();
args3.idesc.set(transformed_ddX, iwo_group);
args3.wdesc.set(*dW, layout, iwo_group);
args3.odesc.set(transformed_dO_channel, iwo_group);
args3.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search3 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_result.algo = search3::Find<T>(
args3, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result =
search3::Find<T>(ctx, args3, exhaustive_search, deterministic);
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
}
if (ddW && dX) {
transformed_dx = transformed_dX.data<T>();
args4.idesc.set(transformed_dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(transformed_dO_channel, iwo_group);
args4.cdesc.set(dtype,
padding_common,
strides,
dilations,
paddle::platform::AllowTF32Cudnn(),
c_group);
#ifdef PADDLE_WITH_HIP
using search4 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_result.algo = search4::Find<T>(
args4, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_result =
search4::Find<T>(ctx, args4, exhaustive_search, deterministic);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
#endif
}
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(
transformed_X.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(transformed_dO_channel.dims(),
DataLayout::kNCHW,
&o_n,
&o_c,
&o_d,
&o_h,
&o_w);
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = W->numel() / groups;
ScalingParamType<T> alpha = 1.0f;
ScalingParamType<T> beta = 0.0f;
// NOTE(zhiqiu): inplace addto is not supported in double grad yet.
// ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f :
// 0.0f;
// VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr<bool>("use_addto");
auto workspace_handle = ctx.cudnn_workspace_handle();
if (ddO) {
if (ddX) {
ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionForward(
handle,
&alpha,
args1.idesc.desc(),
ddx,
args1.wdesc.desc(),
w,
args1.cdesc.desc(),
fwd_result1.algo,
&beta,
args1.odesc.desc(),
transformed_ddy_channel,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionForward(
handle,
&alpha,
args1.idesc.desc(),
ddx + i * group_offset_in,
args1.wdesc.desc(),
w + i * group_offset_filter,
args1.cdesc.desc(),
fwd_result1.algo,
workspace_ptr,
workspace_size,
&beta,
args1.odesc.desc(),
transformed_ddy_channel + i * group_offset_out));
},
workspace_size);
}
#endif
}
if (ddW) {
#ifdef PADDLE_WITH_HIP
// MIOPEN only supports beta equal to 0.0f
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionForward(
handle,
&alpha,
args2.idesc.desc(),
x,
args2.wdesc.desc(),
ddw,
args2.cdesc.desc(),
fwd_result2.algo,
&beta,
args2.odesc.desc(),
transformed_ddy_channel,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionForward(
handle,
&alpha,
args2.idesc.desc(),
x + i * group_offset_in,
args2.wdesc.desc(),
ddw + i * group_offset_filter,
args2.cdesc.desc(),
fwd_result2.algo,
workspace_ptr,
workspace_size,
&alpha,
args2.odesc.desc(),
transformed_ddy_channel + i * group_offset_out));
},
workspace_size);
}
#endif
}
if (channel_last) {
TransToChannelLast<Context, T>(ctx, &transformed_ddO_channel, ddO);
}
}
T* transformed_dy_channel = transformed_dO_channel.data<T>();
if (dW && ddX) {
ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionBackwardWeights(
handle,
&alpha,
args3.odesc.desc(),
transformed_dy_channel,
args3.idesc.desc(),
ddx,
args3.cdesc.desc(),
filter_result.algo,
&beta,
args3.wdesc.desc(),
dw,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionBackwardFilter(
handle,
&alpha,
args3.idesc.desc(),
ddx + i * group_offset_in,
args3.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args3.cdesc.desc(),
filter_result.algo,
workspace_ptr,
workspace_size,
&beta,
args3.wdesc.desc(),
dw + i * group_offset_filter));
},
workspace_size);
}
#endif
}
if (dX && ddW) {
ddw = ddW->data<T>();
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionBackwardData(
handle,
&alpha,
args4.odesc.desc(),
transformed_dy_channel,
args4.wdesc.desc(),
ddw,
args4.cdesc.desc(),
data_result.algo,
&beta,
args4.idesc.desc(),
transformed_dx,
workspace_ptr,
workspace_size));
},
workspace_size);
#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
[&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionBackwardData(
handle,
&alpha,
args4.wdesc.desc(),
ddw + i * group_offset_filter,
args4.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args4.cdesc.desc(),
data_result.algo,
workspace_ptr,
workspace_size,
&beta,
args4.idesc.desc(),
transformed_dx + i * group_offset_in));
},
workspace_size);
}
#endif
if (!is_sys_pad) {
// reverse padded input
std::vector<int> starts(X->dims().size(), 0);
std::vector<int> axes(X->dims().size(), 0);
for (size_t i = 0; i < X->dims().size(); ++i) {
starts[i] = input_pad[2 * i];
axes[i] = i;
}
if (X->dims().size() == 4) {
RemovePaddingSlice<Context, T, 4>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
} else {
RemovePaddingSlice<Context, T, 5>(
ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
}
}
if (channel_last) {
TransToChannelLast<Context, T>(ctx, &transformed_dX_channel, dX);
}
}
}
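// For reference, a notation-only sketch of the identities this double-grad
// kernel evaluates (with conv denoting the forward convolution):
//
//   ddO = conv(ddX, W) + conv(X, ddW)
//   dW  = conv_bwd_filter(ddX, dO)
//   dX  = conv_bwd_data(dO, ddW)
//
// which is why the four ConvArgs descriptors above pair (ddX, W), (X, ddW),
// (ddX, dW) and (dX, ddW) respectively.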
template <typename T, typename Context>
void DepthwiseConvDoubleGradGPUDNNKernel(
const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
bool fuse_relu,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
ConvCudnnGradGradKernel<T>(ctx,
input,
filter,
out_grad,
input_grad_grad,
filter_grad_grad,
strides,
paddings_t,
padding_algorithm,
dilations_t,
groups,
data_format,
input_grad,
filter_grad,
out_grad_grad);
}
template <typename T, typename Context>
void Conv3DCudnnGradGradKernel(
const Context& ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
ConvCudnnGradGradKernel<T>(ctx,
input,
filter,
out_grad,
input_grad_grad,
filter_grad_grad,
strides,
paddings_t,
padding_algorithm,
dilations_t,
groups,
data_format,
input_grad,
filter_grad,
out_grad_grad);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
......@@ -671,6 +1379,26 @@ PD_REGISTER_KERNEL(depthwise_conv2d_grad,
phi::DepthwiseConvCudnnGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(conv2d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(conv3d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3DCudnnGradGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
GPU,
ALL_LAYOUT,
phi::DepthwiseConvDoubleGradGPUDNNKernel,
float,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(conv2d_grad,
......@@ -690,6 +1418,32 @@ PD_REGISTER_KERNEL(conv3d_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(conv2d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(conv3d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3DCudnnGradGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
GPU,
ALL_LAYOUT,
phi::DepthwiseConvDoubleGradGPUDNNKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(conv2d_grad,
GPUDNN,
......@@ -707,6 +1461,29 @@ PD_REGISTER_KERNEL(conv3d_grad,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(conv2d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::ConvCudnnGradGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(conv3d_grad_grad,
GPUDNN,
ALL_LAYOUT,
phi::Conv3DCudnnGradGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(depthwise_conv2d_grad_grad,
GPU,
ALL_LAYOUT,
phi::DepthwiseConvDoubleGradGPUDNNKernel,
float,
double,
phi::dtype::float16) {}
#endif
#endif
......@@ -42,18 +42,23 @@ void ConvCudnnKernel(const Context& ctx,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search_t,
DenseTensor* output) {
ctx.template Alloc<T>(output);
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
bool has_exhaustive_search = ctx.HasDnnAttr("exhaustive_search");
VLOG(4) << "GPUContext contains `exhaustive_search`: "
<< has_exhaustive_search;
bool exhaustive_search_attr =
has_exhaustive_search
? PADDLE_GET_CONST(bool, ctx.GetDnnAttr("exhaustive_search"))
: false;
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || exhaustive_search_attr;
bool deterministic = FLAGS_cudnn_deterministic;
PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
false,
......@@ -402,12 +407,9 @@ void Conv3DCudnnKernel(const Context& dev_ctx,
strides,
paddings,
padding_algorithm,
groups,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
out);
}
......@@ -432,12 +434,9 @@ void DepthwiseConvCudnnKernel(const Context& dev_ctx,
strides,
paddings,
padding_algorithm,
groups,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
out);
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void ConvGradGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
const DenseTensor* X = &input;
const DenseTensor* dY = &out_grad;
const DenseTensor* ddX = input_grad_grad.get_ptr();
const DenseTensor* ddW_in = filter_grad_grad.get_ptr();
DenseTensor* ddY = out_grad_grad;
DenseTensor* dW = filter_grad;
DenseTensor* dX = input_grad;
DenseTensor W = filter;
if (!ddY && !dW && !dX) return;
const std::vector<int> strides = strides_t;
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensor
DenseTensor transformed_X(X->type());
DenseTensor transformed_dY(dY->type());
DenseTensor transformed_ddX(X->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(dev_ctx, X, &transformed_X);
TransToChannelFirst<Context, T>(dev_ctx, X, &transformed_X);
ResizeToChannelFirst<Context, T>(dev_ctx, dY, &transformed_dY);
TransToChannelFirst<Context, T>(dev_ctx, dY, &transformed_dY);
if (ddX) {
ResizeToChannelFirst<Context, T>(dev_ctx, ddX, &transformed_ddX);
TransToChannelFirst<Context, T>(dev_ctx, ddX, &transformed_ddX);
}
} else {
transformed_X = *X;
transformed_dY = *dY;
if (ddX) {
transformed_ddX = *ddX;
}
}
// update padding and dilation
auto in_dims = transformed_X.dims();
auto filter_dims = W.dims();
DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(transformed_X.dims()[0]);
std::vector<int64_t> filter_shape_vec(vectorize(W.dims()));
std::vector<int64_t> output_shape_vec(vectorize(transformed_dY.dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
// col_shape [in_channel/group, kh, kw, oh, ow]
col_shape_vec[0] = transformed_X.dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
}
DDim col_shape(make_ddim(col_shape_vec));
// col_matrix_shape [in_channel/group * kh * kw, oh * ow]
DDim col_matrix_shape = flatten_to_2d(col_shape, data_dim + 1);
// input_shape [Cin, H, W]
DDim input_shape =
slice_ddim(transformed_X.dims(), 1, transformed_X.dims().size());
// filter_matrix_shape [Cout, Cin * kh * kw]
DDim filter_matrix_shape = {W.dims()[0], W.numel() / W.dims()[0]};
W.Resize(filter_matrix_shape);
DDim output_matrix_shape = {
transformed_dY.dims()[1],
transformed_dY.numel() /
(transformed_dY.dims()[0] * transformed_dY.dims()[1])};
int in_step = static_cast<int>(transformed_X.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_dY.dims()[1]) / groups;
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
DenseTensor col;
DenseTensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
dev_ctx.template Alloc<T>(&col);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
phi::funcs::SetConstant<Context, T> set_zero;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
// dx convolution double grad: gemm + col2im(col2vol)
// dx = ddw * dy ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout,
// oH, oW)
if (dX && ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
dev_ctx.template Alloc<T>(dX);
DenseTensor transformed_dX(dX->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(dev_ctx, dX, &transformed_dX);
} else {
transformed_dX = *dX;
}
// if is_expand is false, the operation of set_zero is unnecessary
// because math::matmul will reset dx
if (is_expand) {
set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
}
paddle::operators::math::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
for (int i = 0; i < batch_size; i++) {
DenseTensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
DenseTensor dx_batch = transformed_dX.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
DenseTensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
DenseTensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
DenseTensor dx_slice = dx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col_matrix.ShareDataWith(dx_slice);
col_matrix.Resize(col_matrix_shape);
}
blas.MatMul(
ddw_slice, true, dy_slice, false, T(1.0), &col_matrix, T(0.0));
if (is_expand && data_dim == 2U) {
col2im(dev_ctx,
col,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&dx_slice);
} else if (is_expand && data_dim == 3U) {
col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
}
}
}
if (channel_last) {
TransToChannelLast<Context, T>(dev_ctx, &transformed_dX, dX);
}
}
// dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
// oH, oW)
// dw convolution double grad: im2col(vol2col) + gemm
if (dW && ddX) {
dev_ctx.template Alloc<T>(dW);
set_zero(dev_ctx, dW, static_cast<T>(0));
DenseTensor dW_arr = *dW;
dW_arr.Resize(filter_matrix_shape);
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
Tensor ddx_batch = transformed_ddX.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; ++g) {
// im2col
DenseTensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
DenseTensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx,
ddx_slice,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
DenseTensor dw_slice = dW_arr.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(
dy_slice, false, col_matrix, true, T(1.0), &dw_slice, T(1.0));
}
}
}
// ddy = w * ddx + x * ddw ==> ddy(N, Cout, oH, oW), x/ddx(N, Cin, H, W),
// w/ddw(Cout, Cin, kh, kw)
// ddy convolution double grad: im2col(vol2col) + gemm
if (ddY) {
dev_ctx.template Alloc<T>(ddY);
DenseTensor transformed_ddY(ddY->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(dev_ctx, ddY, &transformed_ddY);
} else {
transformed_ddY = *ddY;
}
set_zero(dev_ctx, &transformed_ddY, static_cast<T>(0));
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor ddy_batch =
transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; ++g) {
// gemm
DenseTensor ddy_slice =
ddy_batch.Slice(g * out_step, (g + 1) * out_step);
if (ddX) {
DenseTensor ddx_batch =
transformed_ddX.Slice(i, i + 1).Resize(input_shape);
DenseTensor ddx_slice =
ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx,
ddx_slice,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
DenseTensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(
w_slice, false, col_matrix, false, T(1.0), &ddy_slice, T(0.0));
}
if (ddW_in) {
DenseTensor x_batch =
transformed_X.Slice(i, i + 1).Resize(input_shape);
DenseTensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
DenseTensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
if (!is_expand) {
col.ShareDataWith(x_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx,
x_slice,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
}
// gemm
DenseTensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(
ddw_slice, false, col_matrix, false, T(1.0), &ddy_slice, T(1.0));
}
}
}
if (channel_last) {
TransToChannelLast<Context, T>(dev_ctx, &transformed_ddY, ddY);
}
}
}
} // namespace phi
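The CPU double-grad path above is assembled from im2col/col2im (vol2col/col2vol for 3-D) plus per-sample, per-group GEMMs. As a self-contained illustration, independent of Paddle's functors and assuming NCHW layout, stride 1, no padding and no dilation, a naive im2col lays each receptive field out as a column so that convolution becomes a single matrix product against the [Cout, Cin * kh * kw] filter matrix used above:

#include <vector>
// Naive im2col for one input of shape [C, H, W], kernel [kh, kw],
// stride 1, no padding; the result has shape [C * kh * kw, oh * ow],
// matching the col_matrix_shape used in the kernels above.
std::vector<float> Im2ColNaive(const std::vector<float>& in,
                               int C, int H, int W, int kh, int kw) {
  const int oh = H - kh + 1, ow = W - kw + 1;
  std::vector<float> col(static_cast<size_t>(C) * kh * kw * oh * ow);
  for (int c = 0; c < C; ++c)
    for (int i = 0; i < kh; ++i)
      for (int j = 0; j < kw; ++j)
        for (int y = 0; y < oh; ++y)
          for (int x = 0; x < ow; ++x) {
            const int row = (c * kh + i) * kw + j;  // row in [0, C*kh*kw)
            const int colp = y * ow + x;            // column in [0, oh*ow)
            col[static_cast<size_t>(row) * oh * ow + colp] =
                in[(c * H + (y + i)) * W + (x + j)];
          }
  return col;
}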
......@@ -16,7 +16,6 @@
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/conv_grad_kernel.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
......@@ -32,12 +31,9 @@ void ConvGradKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad) {
// The filter and filter_grad will be reshaped in the calculations,
......@@ -254,4 +250,304 @@ void ConvGradKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void ConvGradGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const paddle::optional<DenseTensor>& input_grad_grad,
const paddle::optional<DenseTensor>& filter_grad_grad,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
DenseTensor* input_grad,
DenseTensor* filter_grad,
DenseTensor* out_grad_grad) {
const DenseTensor* X = &input;
const DenseTensor* dY = &out_grad;
const DenseTensor* ddX = input_grad_grad.get_ptr();
const DenseTensor* ddW_in = filter_grad_grad.get_ptr();
DenseTensor* ddY = out_grad_grad;
DenseTensor* dW = filter_grad;
DenseTensor* dX = input_grad;
DenseTensor W = filter;
if (!ddY && !dW && !dX) return;
const std::vector<int> strides = strides_t;
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensor
DenseTensor transformed_X(X->type());
DenseTensor transformed_dY(dY->type());
DenseTensor transformed_ddX(X->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(dev_ctx, X, &transformed_X);
TransToChannelFirst<Context, T>(dev_ctx, X, &transformed_X);
ResizeToChannelFirst<Context, T>(dev_ctx, dY, &transformed_dY);
TransToChannelFirst<Context, T>(dev_ctx, dY, &transformed_dY);
if (ddX) {
ResizeToChannelFirst<Context, T>(dev_ctx, ddX, &transformed_ddX);
TransToChannelFirst<Context, T>(dev_ctx, ddX, &transformed_ddX);
}
} else {
transformed_X = *X;
transformed_dY = *dY;
if (ddX) {
transformed_ddX = *ddX;
}
}
// update padding and dilation
auto in_dims = transformed_X.dims();
auto filter_dims = W.dims();
DDim in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
DDim filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(transformed_X.dims()[0]);
std::vector<int64_t> filter_shape_vec(vectorize(W.dims()));
std::vector<int64_t> output_shape_vec(vectorize(transformed_dY.dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
// col_shape [in_channel/group, kh, kw, oh, ow]
col_shape_vec[0] = transformed_X.dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
}
DDim col_shape(make_ddim(col_shape_vec));
// col_matrix_shape [in_channel/group * kh * kw, oh * ow]
DDim col_matrix_shape = flatten_to_2d(col_shape, data_dim + 1);
// input_shape [Cin, H, W]
DDim input_shape =
slice_ddim(transformed_X.dims(), 1, transformed_X.dims().size());
// filter_matrix_shape [Cout, Cin * kh * kw]
DDim filter_matrix_shape = {W.dims()[0], W.numel() / W.dims()[0]};
W.Resize(filter_matrix_shape);
DDim output_matrix_shape = {
transformed_dY.dims()[1],
transformed_dY.numel() /
(transformed_dY.dims()[0] * transformed_dY.dims()[1])};
int in_step = static_cast<int>(transformed_X.dims()[1]) / groups;
int out_step = static_cast<int>(transformed_dY.dims()[1]) / groups;
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
DenseTensor col;
DenseTensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
dev_ctx.template Alloc<T>(&col);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
phi::funcs::SetConstant<Context, T> set_zero;
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
// dx convolution double grad: gemm + col2im(col2vol)
// dx = ddw * dy ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout,
// oH, oW)
if (dX && ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
dev_ctx.template Alloc<T>(dX);
DenseTensor transformed_dX(dX->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(dev_ctx, dX, &transformed_dX);
} else {
transformed_dX = *dX;
}
// if is_expand is false, the operation of set_zero is unnecessary
// because math::matmul will reset dx
if (is_expand) {
set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
}
paddle::operators::math::Col2VolFunctor<Context, T> col2vol;
paddle::operators::math::
Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
col2im;
for (int i = 0; i < batch_size; i++) {
DenseTensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
DenseTensor dx_batch = transformed_dX.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
DenseTensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
DenseTensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
DenseTensor dx_slice = dx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col_matrix.ShareDataWith(dx_slice);
col_matrix.Resize(col_matrix_shape);
}
blas.MatMul(
ddw_slice, true, dy_slice, false, T(1.0), &col_matrix, T(0.0));
if (is_expand && data_dim == 2U) {
col2im(dev_ctx,
col,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&dx_slice);
} else if (is_expand && data_dim == 3U) {
col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
}
}
}
if (channel_last) {
TransToChannelLast<Context, T>(dev_ctx, &transformed_dX, dX);
}
}
// dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
// oH, oW)
// dw convolution double grad: im2col(vol2col) + gemm
if (dW && ddX) {
dev_ctx.template Alloc<T>(dW);
set_zero(dev_ctx, dW, static_cast<T>(0));
DenseTensor dW_arr = *dW;
dW_arr.Resize(filter_matrix_shape);
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor dy_batch =
transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
Tensor ddx_batch = transformed_ddX.Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; ++g) {
// im2col
DenseTensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
DenseTensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx,
ddx_slice,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
DenseTensor dw_slice = dW_arr.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(
dy_slice, false, col_matrix, true, T(1.0), &dw_slice, T(1.0));
}
}
}
// ddy = w * ddx + x * ddw ==> ddy(N, Cout, oH, oW), x/ddx(N, Cin, H, W),
// w/ddw(Cout, Cin, kh, kw)
// ddy convolution double grad: im2col(vol2col) + gemm
if (ddY) {
dev_ctx.template Alloc<T>(ddY);
DenseTensor transformed_ddY(ddY->type());
if (channel_last) {
ResizeToChannelFirst<Context, T>(dev_ctx, ddY, &transformed_ddY);
} else {
transformed_ddY = *ddY;
}
set_zero(dev_ctx, &transformed_ddY, static_cast<T>(0));
paddle::operators::math::
Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
im2col;
paddle::operators::math::Vol2ColFunctor<Context, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
DenseTensor ddy_batch =
transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; ++g) {
// gemm
DenseTensor ddy_slice =
ddy_batch.Slice(g * out_step, (g + 1) * out_step);
if (ddX) {
DenseTensor ddx_batch =
transformed_ddX.Slice(i, i + 1).Resize(input_shape);
DenseTensor ddx_slice =
ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx,
ddx_slice,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
DenseTensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(
w_slice, false, col_matrix, false, T(1.0), &ddy_slice, T(0.0));
}
if (ddW_in) {
DenseTensor x_batch =
transformed_X.Slice(i, i + 1).Resize(input_shape);
DenseTensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
DenseTensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
if (!is_expand) {
col.ShareDataWith(x_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx,
x_slice,
dilations,
strides,
std::vector<int>{
paddings[0], paddings[2], paddings[1], paddings[3]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
}
// gemm
DenseTensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(
ddw_slice, false, col_matrix, false, T(1.0), &ddy_slice, T(1.0));
}
}
}
if (channel_last) {
TransToChannelLast<Context, T>(dev_ctx, &transformed_ddY, ddY);
}
}
}
} // namespace phi
......@@ -25,7 +25,7 @@
namespace phi {
template <typename T, typename Context>
void ConvKernel(const Context& dev_ctx,
void ConvKernelImpl(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter_t,
const std::vector<int>& strides,
......@@ -34,9 +34,6 @@ void ConvKernel(const Context& dev_ctx,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* output) {
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
......
......@@ -13,9 +13,6 @@
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
auto eigen_in = EigenVector<T>::Flatten(x);
auto eigen_out = EigenVector<T>::Flatten(*out);
auto& place = *ctx.eigen_device();
constexpr T half = static_cast<T>(0.5);
constexpr T half_sqrt = static_cast<T>(M_SQRT1_2);
eigen_out.device(place) = (eigen_in * half + half).ndtri() * half_sqrt;
}
} // namespace phi
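// The Eigen expression above relies on the identity between the inverse
// error function and the inverse CDF of the standard normal distribution
// (ndtri):
//
//   erfinv(x) = ndtri((x + 1) / 2) / sqrt(2)
//
// which follows from erf(x) = 2 * Phi(x * sqrt(2)) - 1.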
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
#include "paddle/phi/kernels/onednn/conv_handler.h"
namespace phi {
#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::BFLOAT16, \
::phi::dtype::bfloat16, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
template <typename T, typename Context>
void ConvGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
DenseTensor* input_grad,
DenseTensor* filter_grad) {
PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType(),
AllocationType::CPU,
phi::errors::PreconditionNotMet(
"Operator DNNL ConvGrad must use CPUPlace"));
const auto& onednn_engine = dev_ctx.GetEngine();
const auto* bias =
dev_ctx.HasDnnInput("Bias") ? dev_ctx.GetDnnInput("Bias") : nullptr;
bool is_test = dev_ctx.HasDnnAttr("is_test")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
: false;
if (!input_grad && !filter_grad) return;
const std::string& unique_name =
dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0];
PD_VISIT_FLOAT_AND_BF16_TYPES(
filter.dtype(), "ConvOneDNNHandlerT", ([&] {
// TODO(jczaja): Are all tensors really needed?
onednn::ConvOneDNNHandlerT<T, data_t, T> handler(dev_ctx,
dev_ctx.GetPlace(),
&input,
&filter,
bias,
&out_grad,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
filter_grad,
input_grad,
unique_name);
// create oneDNN memory from input tensors (data/weights)
auto& astream = OneDNNContext::tls().get_stream();
if (filter_grad) {
auto src_memory_p =
handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(&input);
auto diff_dst_memory_p =
handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
&out_grad);
// For convolution with groups, write the filter grad into a
// oneDNN buffer and then reorder it into the filter_grad tensor
int g = std::max(groups, 1);
auto diff_weights_memory_p =
g > 1 ? handler.AcquireDiffWeightsMemory()
: handler.AcquireDiffWeightsMemory(filter_grad);
auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();
conv_bwd_weights_p->execute(
astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait();
// For convolution with groups, convert from blocked to NCHW;
// otherwise there will be problems in the next operators working
// on this data
if (g > 1) {
// In oneDNN, groups in convolution are treated as a separate
// dimension, which is not the case in PaddlePaddle
dnnl::memory::data_type in_type =
funcs::ToOneDNNDataType(filter.dtype());
// For 3d conv with groups, six-dimensional data is reordered to
// goidhw; for 2d conv with groups, five-dimensional data is
// reordered to goihw
auto weights_tz = diff_weights_memory_p->get_desc().dims();
dnnl::memory::format_tag out_format =
weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw
: dnnl::memory::format_tag::goihw;
funcs::ReorderOneDNNHandler handler(
weights_tz, filter.dtype(), in_type, onednn_engine);
auto reorder_dst_memory_p = handler.AcquireDstMemory(
filter_grad, out_format, dev_ctx.GetPlace());
auto reorder_p = handler.AcquireReorder(reorder_dst_memory_p,
diff_weights_memory_p);
{
reorder_p->execute(
astream, *diff_weights_memory_p, *reorder_dst_memory_p);
astream.wait();
}
// So here we have data in goihw, which can be interpreted as
// OIHW (OIDHW for conv3d) because filter_grad shape is set for
// OIHW (OIDHW for conv3d)
dnnl::memory::format_tag target_format =
weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
: dnnl::memory::format_tag::oihw;
filter_grad->set_mem_desc(
dnnl::memory::desc(phi::vectorize<int64_t>(filter_grad->dims()),
in_type,
target_format));
} else {
filter_grad->set_mem_desc(diff_weights_memory_p->get_desc());
}
}
if (input_grad) {
auto weights_memory_p =
handler.AcquireWeightsMemoryWithReorderFromDataPrimitive(
&filter, groups, strides.size() == 3U);
auto diff_dst_memory_p =
handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
&out_grad);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad);
auto conv_bwd_data_p = handler.AcquireBackwardPrimitive();
conv_bwd_data_p->execute(astream,
{{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
input_grad->set_mem_desc(diff_src_memory_p->get_desc());
}
}));
}
} // namespace phi
PD_REGISTER_KERNEL(conv2d_grad,
OneDNN,
ONEDNN,
phi::ConvGradKernel,
float,
phi::dtype::bfloat16) {}
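// Besides extra attributes, the oneDNN grad kernel above also pulls an extra
// input (Bias) and an extra attribute (is_test) out of the device context
// rather than out of the kernel signature. A minimal sketch of the
// optional-input lookup, using only the HasDnnInput/GetDnnInput calls shown
// in ConvGradKernel (the helper name is hypothetical):
//
// const phi::DenseTensor* GetOptionalBias(const phi::OneDNNContext& dev_ctx) {
//   // Returns nullptr when no Bias was attached as an extra oneDNN input.
//   return dev_ctx.HasDnnInput("Bias") ? dev_ctx.GetDnnInput("Bias") : nullptr;
// }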
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/onednn/onednn_helper.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/expect.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace phi {
namespace onednn {
inline funcs::OneDNNMemoryFormat GetWeightsFormat(int groups, bool is_conv3d) {
if (is_conv3d) {
return (groups == 1) ? funcs::OneDNNMemoryFormat::oidhw
: funcs::OneDNNMemoryFormat::goidhw;
} else {
return (groups == 1) ? funcs::OneDNNMemoryFormat::oihw
: funcs::OneDNNMemoryFormat::goihw;
}
}
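// GetWeightsFormat above picks oihw/goihw (or oidhw/goidhw) depending on
// whether groups are used; the matching dimension change is performed by
// funcs::GetGroupConvWeightsTz in the handlers below. A rough sketch of that
// transformation, assuming the filter arrives as [O, I, H, W] (the helper
// name is hypothetical, not Paddle's API):
//
// std::vector<int64_t> ToGroupedWeightsDims(std::vector<int64_t> oihw,
//                                           int groups) {
//   if (groups > 1) {
//     oihw[0] /= groups;                  // output channels per group
//     oihw.insert(oihw.begin(), groups);  // prepend group dim -> goihw
//   }
//   return oihw;
// }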
template <typename T, typename K, typename T_out>
class ConvOneDNNHandlerT
: public funcs::OneDNNHandlerT<T,
dnnl::convolution_forward,
dnnl::convolution_backward_data,
dnnl::convolution_backward_weights> {
public:
ConvOneDNNHandlerT(const OneDNNContext& dev_ctx,
const dnnl::engine mkldnn_engine,
Place cpu_place,
const phi::DenseTensor* input,
const phi::DenseTensor* filter,
const phi::DenseTensor* bias,
const std::vector<int>& strides_in,
const std::vector<int>& paddings_in,
const std::string& padding_algorithm,
const std::vector<int>& dilations_in,
int groups,
const std::string& data_format,
bool is_test,
bool is_BFLOAT16,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
phi::DenseTensor* output,
const std::string& unique_name)
: funcs::OneDNNHandlerT<T,
dnnl::convolution_forward,
dnnl::convolution_backward_data,
dnnl::convolution_backward_weights>(
dev_ctx,
mkldnn_engine,
cpu_place,
funcs::CreateKey(
dev_ctx, phi::vectorize(input->dims()), unique_name)) {
if (unlikely(!this->isCached())) {
PADDLE_ENFORCE_EQ(
input->layout(),
DataLayout::ONEDNN,
phi::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.",
DataLayout::ONEDNN,
input->layout()));
PADDLE_ENFORCE_EQ(
filter->layout(),
DataLayout::ONEDNN,
phi::errors::InvalidArgument(
"The Filter tensor's layout should be %d, but got %d.",
DataLayout::ONEDNN,
filter->layout()));
PADDLE_ENFORCE_GE(
input->dims().size(),
4,
phi::errors::InvalidArgument(
"Input must be with 4 or 5 dimensions, i.e. NCHW or "
"NCDHW, but got dimension = %d .",
input->dims().size()));
PADDLE_ENFORCE_LE(
input->dims().size(),
5,
phi::errors::InvalidArgument(
"Input must be with 4 or 5 dimensions, i.e. NCHW or "
"NCDHW, but got dimension = %d .",
input->dims().size()));
PADDLE_ENFORCE_GE(
filter->dims().size(),
4,
phi::errors::InvalidArgument(
"Filter must be with 4 or 5 dimensions, i.e. OIHW or "
"OIDHW, but got dimension = %d .",
filter->dims().size()));
PADDLE_ENFORCE_LE(
filter->dims().size(),
5,
phi::errors::InvalidArgument(
"Filter must be with 4 or 5 dimensions, i.e. OIHW or "
"OIDHW, but got dimension = %d .",
filter->dims().size()));
if (bias) {
PADDLE_ENFORCE_EQ(
bias->layout(),
DataLayout::ONEDNN,
phi::errors::InvalidArgument(
"The Bias tensor's layout should be %d, but got %d.",
DataLayout::ONEDNN,
bias->layout()));
PADDLE_ENFORCE_EQ(
bias->dims().size(),
1,
phi::errors::InvalidArgument("Bias must only have 1 dimension, "
"i.e. X, but got dimension = %d .",
bias->dims().size()));
}
const auto input_dims = input->dims();
const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
const auto filter_dims = filter->dims();
const auto filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
const auto ksize = phi::vectorize(filter_data_dims);
std::vector<int64_t> strides(begin(strides_in), end(strides_in));
std::vector<int64_t> paddings(begin(paddings_in), end(paddings_in));
std::vector<int64_t> dilations(begin(dilations_in), end(dilations_in));
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, data_dims, strides, ksize);
std::transform(
dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) {
return i - 1;
});
const auto src_tz = phi::vectorize(input->dims());
auto weights_tz = phi::vectorize(filter->dims());
funcs::GetGroupConvWeightsTz(weights_tz, groups);
const auto dst_tz = phi::vectorize(output->dims());
const dnnl::memory::dims stride_dims = strides;
const auto onednn_paddings = funcs::ToOnednnPadding(paddings);
const dnnl::memory::dims dilations_dims = dilations;
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
auto chosen_memory_format = funcs::OneDNNMemoryFormat::any;
auto data_type = dnnl::memory::data_type::f32;
if (is_BFLOAT16 || std::is_same<T_out, dtype::bfloat16>::value) {
data_type = dnnl::memory::data_type::bf16;
}
dnnl::memory::desc src_md, weights_md;
if (funcs::is_int8<T>()) {
src_md = funcs::OneDNNMemDesc(src_tz,
funcs::ToOneDNNDataType(input->dtype()),
chosen_memory_format);
weights_md = funcs::OneDNNMemDesc(
weights_tz, dnnl::memory::data_type::s8, chosen_memory_format);
} else {
src_md = funcs::OneDNNMemDesc(src_tz, data_type, chosen_memory_format);
weights_md = funcs::OneDNNMemDesc(
weights_tz, data_type, funcs::OneDNNMemoryFormat::any);
}
const auto dst_md = funcs::OneDNNMemDesc(
dst_tz, funcs::OneDNNGetDataType<T_out>(), chosen_memory_format);
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training;
const dnnl::primitive_attr conv_attr = CreateConvAttrs(filter,
groups,
force_fp32_output,
fuse_residual_conn,
fuse_activation);
if (bias) {
auto bias_tz = phi::vectorize(bias->dims());
dnnl::memory::desc bias_md;
if (funcs::is_int8<T>()) {
bias_md = funcs::OneDNNMemDesc(bias_tz,
dnnl::memory::data_type::s32,
funcs::OneDNNMemoryFormat::x);
} else {
bias_md = funcs::OneDNNMemDesc(
bias_tz, data_type, funcs::OneDNNMemoryFormat::x);
}
this->AcquireForwardPrimitiveDescriptor(
conv_attr,
fwd_prop_kind,
dnnl::algorithm::convolution_direct,
src_md,
weights_md,
bias_md,
dst_md,
stride_dims,
dilations_dims,
onednn_paddings[0],
onednn_paddings[1]);
} else {
this->AcquireForwardPrimitiveDescriptor(
conv_attr,
fwd_prop_kind,
dnnl::algorithm::convolution_direct,
src_md,
weights_md,
dst_md,
stride_dims,
dilations_dims,
onednn_paddings[0],
onednn_paddings[1]);
}
}
}
ConvOneDNNHandlerT(const OneDNNContext& dev_ctx,
Place cpu_place,
const phi::DenseTensor* in,
const phi::DenseTensor* filter,
const phi::DenseTensor* bias,
const phi::DenseTensor* out_grad,
const std::vector<int>& strides_in,
const std::vector<int>& paddings_in,
const std::string& padding_algorithm,
const std::vector<int>& dilations_in,
int groups,
const std::string& data_format,
bool is_test,
phi::DenseTensor* filter_grad,
phi::DenseTensor* in_x_grad,
const std::string& unique_name)
: funcs::OneDNNHandlerT<T,
dnnl::convolution_forward,
dnnl::convolution_backward_data,
dnnl::convolution_backward_weights>(
dev_ctx,
dev_ctx.GetEngine(),
cpu_place,
funcs::CreateKey(
dev_ctx, phi::vectorize(in->dims()), unique_name)) {
if (unlikely(!this->isBwdCached())) {
PADDLE_ENFORCE_EQ(
in->layout(),
DataLayout::ONEDNN,
phi::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.",
DataLayout::ONEDNN,
in->layout()));
PADDLE_ENFORCE_EQ(
filter->layout(),
DataLayout::ONEDNN,
phi::errors::InvalidArgument(
"The filter tensor's layout should be %d, but got %d.",
DataLayout::ONEDNN,
filter->layout()));
PADDLE_ENFORCE_EQ(
out_grad->layout(),
DataLayout::ONEDNN,
phi::errors::InvalidArgument(
"The output_grad tensor's layout should be %d, but got %d.",
DataLayout::ONEDNN,
out_grad->layout()));
PADDLE_ENFORCE_EQ(
is_test,
false,
phi::errors::InvalidArgument(
"is_test attribute should be set to False in training phase."));
std::vector<int64_t> strides(begin(strides_in), end(strides_in));
std::vector<int64_t> paddings(begin(paddings_in), end(paddings_in));
std::vector<int64_t> dilations(begin(dilations_in), end(dilations_in));
auto input_dims = in->dims();
auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
auto filter_dims = filter->dims();
auto filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = phi::vectorize(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, data_dims, strides, ksize);
auto src_tz = phi::vectorize(in->dims());
auto weights_tz = phi::vectorize(filter->dims());
int g = std::max(groups, 1);
funcs::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = phi::vectorize(out_grad->dims());
/* create memory descriptor for conv backward without specified format
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
const auto chosen_memory_format = funcs::OneDNNMemoryFormat::any;
const auto weights_format = funcs::OneDNNMemoryFormat::any;
auto src_md = funcs::OneDNNMemDesc(
src_tz, funcs::OneDNNGetDataType<T>(), chosen_memory_format);
const auto dst_md = funcs::OneDNNMemDesc(
dst_tz, funcs::OneDNNGetDataType<T_out>(), chosen_memory_format);
auto diff_src_md = funcs::OneDNNMemDesc(
src_tz, funcs::OneDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = funcs::OneDNNMemDesc(
weights_tz, funcs::OneDNNGetDataType<T>(), weights_format);
auto diff_weights_md = funcs::OneDNNMemDesc(
weights_tz, funcs::OneDNNGetDataType<T>(), weights_format);
auto diff_dst_md = funcs::OneDNNMemDesc(
dst_tz, funcs::OneDNNGetDataType<T>(), chosen_memory_format);
auto onednn_paddings = funcs::ToOnednnPadding(paddings);
std::transform(
dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) {
return i - 1;
});
const dnnl::memory::dims dilations_dims = dilations;
const dnnl::memory::dims stride_dims = strides;
// Recreating FWD PD. For training there are no post ops in convolution
dnnl::primitive_attr conv_attr;
if (bias) {
auto bias_tz = phi::vectorize(bias->dims());
dnnl::memory::desc bias_md;
if (funcs::is_int8<T>()) {
bias_md = funcs::OneDNNMemDesc(bias_tz,
dnnl::memory::data_type::s32,
funcs::OneDNNMemoryFormat::x);
} else {
bias_md = funcs::OneDNNMemDesc(bias_tz,
dnnl::memory::data_type::f32,
funcs::OneDNNMemoryFormat::x);
}
this->AcquireForwardPrimitiveDescriptor(
conv_attr,
dnnl::prop_kind::forward_training,
dnnl::algorithm::convolution_direct,
src_md,
weights_md,
bias_md,
dst_md,
stride_dims,
dilations_dims,
onednn_paddings[0],
onednn_paddings[1]);
} else {
this->AcquireForwardPrimitiveDescriptor(
conv_attr,
dnnl::prop_kind::forward_training,
dnnl::algorithm::convolution_direct,
src_md,
weights_md,
dst_md,
stride_dims,
dilations_dims,
onednn_paddings[0],
onednn_paddings[1]);
}
this->AcquireBackwardPrimitiveDescriptor(
dnnl::algorithm::convolution_direct,
diff_src_md,
weights_md,
diff_dst_md,
strides,
dilations_dims,
onednn_paddings[0],
onednn_paddings[1]);
this->AcquireBackwardWeightsPrimitiveDescriptor(
dnnl::algorithm::convolution_direct,
src_md,
diff_weights_md,
diff_dst_md,
strides,
dilations_dims,
onednn_paddings[0],
onednn_paddings[1]);
}
}
std::shared_ptr<std::tuple<float, std::vector<float>>> get_int8_bias_scales(
const DenseTensor* filter,
int groups,
const std::vector<float>& scale_weights_data) {
    // Cache key for the int8 bias scales
const std::string key_bs = this->key_ + "@bs";
// Scales for int8 bias are to be cached to avoid
// computing them each iteration
groups = std::max(groups, 1);
auto bias_scale_tuple =
std::static_pointer_cast<std::tuple<float, std::vector<float>>>(
this->dev_ctx_.GetBlob(key_bs));
if (bias_scale_tuple) return bias_scale_tuple;
const auto& weights_tz = phi::vectorize(filter->dims());
const auto& scale_in_data =
this->dev_ctx_.HasDnnAttr("Scale_in")
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Scale_in"))
: 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1;
int mask_reorder = is_multi_channel ? 1 << 0 : 1;
int count = 1;
if (is_multi_channel) {
count *= weights_tz[0];
if (groups > 1) {
count *= weights_tz[1];
}
}
bias_scale_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(std::make_tuple(
static_cast<float>(mask_reorder), std::vector<float>(count)));
for (int i = 0; i < count; i++) {
std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i];
}
this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple);
return bias_scale_tuple;
}
std::tuple<float, std::vector<float>, float> get_int8_scales(
const DenseTensor* filter,
int groups,
bool force_fp32_output,
bool fuse_residual_conn,
const std::string& fuse_activation) const {
const auto& weights_tz = phi::vectorize(filter->dims());
groups = std::max(groups, 1);
const auto& scale_weights_data =
this->dev_ctx_.HasDnnAttr("Scale_weights")
? PADDLE_GET_CONST(std::vector<float>,
this->dev_ctx_.GetDnnAttr("Scale_weights"))
: std::vector<float>{1.0f};
const auto& scale_in_data =
this->dev_ctx_.HasDnnAttr("Scale_in")
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Scale_in"))
: 1.0f;
const auto& scale_in_eltwise_data =
this->dev_ctx_.HasDnnAttr("Scale_in_eltwise")
? PADDLE_GET_CONST(float,
this->dev_ctx_.GetDnnAttr("Scale_in_eltwise"))
: 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1;
bool has_activation = !fuse_activation.empty();
const auto& scale_out =
this->dev_ctx_.HasDnnAttr("Scale_out")
? PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Scale_out"))
: 1.0f;
float activation_scale =
(!force_fp32_output && has_activation) ? scale_out : 1.0f;
float scale_out_data =
(force_fp32_output || has_activation) ? 1.0f : scale_out;
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
int count =
is_multi_channel
? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
: 1;
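    // Requantization: the int32 accumulator carries scale_in * scale_weight,
    // so scaling it by scale_out / (scale_in * scale_weight[i]) maps the
    // result onto the output scale; a zero weight scale falls back to
    // scale_out alone (handled in the loop below).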
std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 50)
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
        // some models contain zero-valued weights, for which a weights
        // scale cannot be computed, so fall back to the output scale
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
}
dnnl::primitive_attr CreateConvAttrs(const DenseTensor* filter,
int groups,
bool force_fp32_output,
bool fuse_residual_conn,
const std::string& fuse_activation) {
dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations;
float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (funcs::is_int8<T>()) {
if (this->dev_ctx_.HasDnnAttr("Sum_scale")) {
sum_scale =
PADDLE_GET_CONST(float, this->dev_ctx_.GetDnnAttr("Sum_scale"));
activation_scale =
this->dev_ctx_.HasDnnAttr("Activation_scale")
? PADDLE_GET_CONST(
float, this->dev_ctx_.GetDnnAttr("Activation_scale"))
: activation_scale;
output_shift_scale =
this->dev_ctx_.HasDnnAttr("Output_shift_scale")
? PADDLE_GET_CONST(
std::vector<float>,
this->dev_ctx_.GetDnnAttr("Output_shift_scale"))
: output_shift_scale;
} else {
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(filter,
groups,
force_fp32_output,
fuse_residual_conn,
fuse_activation);
}
if (output_shift_scale.size() > 0) {
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
}
}
    // Fusion with an elementwise layer relies on adding a sum post-operation
    // with the scale parameter. When fuse_residual_connection is true, the
    // output tensor is assumed to already hold the data coming from the
    // residual connection, so the result of this post-op is:
    //   Output = scale * Output + Conv_Out.
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
funcs::AppendActivation(this->dev_ctx_, post_operations, activation_scale);
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::shared_ptr<dnnl::memory>
AcquireWeightsMemoryWithReorderFromDataPrimitive(
const phi::DenseTensor* filter, const int groups, const bool is_conv3d) {
const K* filter_data = filter->data<K>();
auto weights_tz = phi::vectorize(filter->dims());
funcs::GetGroupConvWeightsTz(weights_tz, groups);
auto user_src_md =
funcs::OneDNNMemDesc(weights_tz,
funcs::OneDNNGetDataType<K>(),
GetWeightsFormat(groups, is_conv3d));
return this->AcquireMemoryWithReorder(user_src_md,
this->bwd_pd_->weights_desc(),
funcs::to_void_cast<K>(filter_data),
"@weights_mem_d_p",
false);
}
std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
const phi::DenseTensor* input) {
return this->AcquireMemoryWithReorderPrimitive(input,
"@src_mem_p_user",
"@src_mem_p_target",
"@src_mem_p",
this->fwd_pd_->src_desc());
}
std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorderFromWeightsPrimitive(
const phi::DenseTensor* input) {
return this->AcquireMemoryWithReorderPrimitive(input,
"@src_mem_w_p_user",
"@src_mem_w_p_target",
"@src_mem_w_p",
this->bwd_w_pd_->src_desc());
}
std::shared_ptr<dnnl::memory>
AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
const phi::DenseTensor* out_grad) {
return this->AcquireMemoryWithReorderPrimitive(
out_grad,
"@diff_dst_mem_w_p_user",
"@diff_dst_mem_w_p_target",
"@diff_dst_mem_w_p",
this->bwd_w_pd_->diff_dst_desc());
}
std::shared_ptr<dnnl::memory>
AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
const phi::DenseTensor* out_grad) {
return this->AcquireMemoryWithReorderPrimitive(
out_grad,
"@diff_dst_mem_p_user",
"@diff_dst_mem_p_target",
"@diff_dst_mem_p",
this->bwd_pd_->diff_dst_desc());
}
std::shared_ptr<dnnl::memory> AcquireMemoryWithReorderPrimitive(
const phi::DenseTensor* in_mem,
const char* key_mem_user,
const char* key_mem_target,
const char* key_mem,
const dnnl::memory::desc& mem_md) {
const T* in_mem_data = in_mem->data<T>();
const std::string user_key_suffix{key_mem_user};
auto user_mem_p = this->AcquireMemory(user_key_suffix);
if (!user_mem_p) {
return this->AcquireMemoryWithReorder(in_mem->mem_desc(),
mem_md,
funcs::to_void_cast<T>(in_mem_data),
key_mem);
} else {
const std::string target_key_suffix{key_mem_target};
const auto target_mem_p = this->AcquireMemory(target_key_suffix);
user_mem_p->set_data_handle(funcs::to_void_cast<T>(in_mem_data));
if (user_mem_p != target_mem_p) {
this->AcquireReorder(user_mem_p, target_mem_p);
}
return target_mem_p;
}
}
std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
const phi::DenseTensor* filter,
const int groups,
const bool is_conv3d,
const bool is_test,
const std::vector<float>& scale_data = {1.0f},
int mask = 0) {
    // This is a workaround to make execution faster; remove this branching
    // once the memory descriptor (md) is stored inside the Tensor itself.
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) {
return weights_mem_p;
} else if (is_test) {
const K* filter_data = filter->data<K>();
auto weights_tz = phi::vectorize(filter->dims());
funcs::GetGroupConvWeightsTz(weights_tz, groups);
auto user_src_md =
funcs::OneDNNMemDesc(weights_tz,
funcs::OneDNNGetDataType<K>(),
GetWeightsFormat(groups, is_conv3d));
return this->AcquireMemoryWithReorder(user_src_md,
this->fwd_pd_->weights_desc(),
funcs::to_void_cast<K>(filter_data),
"@weights_mem_p",
is_test,
{},
scale_data,
mask);
} else {
const T* filter_data = filter->data<T>();
auto weights_tz = phi::vectorize(filter->dims());
funcs::GetGroupConvWeightsTz(weights_tz, groups);
auto user_src_md =
funcs::OneDNNMemDesc(weights_tz,
funcs::OneDNNGetDataType<T>(),
GetWeightsFormat(groups, is_conv3d));
return this->AcquireMemoryWithReorder(user_src_md,
this->fwd_pd_->weights_desc(),
funcs::to_void_cast<T>(filter_data),
"@weights_mem_p",
is_test,
{},
scale_data,
mask);
}
}
std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
const phi::DenseTensor* bias,
const bool is_test,
const std::vector<float>& scale_data = {1.0f},
int mask = 0) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) {
return bias_mem_p;
} else {
// if K is int8 (weights are int8) then biases are int32
using K_Bias = typename std::
conditional<std::is_same<K, int8_t>::value, int32_t, K>::type;
if (std::is_same<K_Bias, int32_t>::value &&
bias->dtype() != phi::DataType::INT32) {
LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype();
}
const K_Bias* bias_data = bias->data<K_Bias>();
return this->AcquireMemoryWithReorder(
bias->mem_desc(),
this->fwd_pd_->bias_desc(),
funcs::to_void_cast<K_Bias>(bias_data),
"@bias_mem_p",
is_test,
{},
scale_data,
mask);
}
}
std::shared_ptr<dnnl::memory> AcquireResidualMemory(
const phi::DenseTensor* residual_param) {
void* residual_data =
residual_param->dtype() ==
paddle::experimental::CppTypeToDataType<T_out>::Type()
? funcs::to_void_cast<T_out>(residual_param->data<T_out>())
: funcs::to_void_cast<T>(residual_param->data<T>());
auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
if (residual_mem_p) {
residual_mem_p->set_data_handle(residual_data);
return residual_mem_p;
} else {
return this->AcquireMemoryFromPrimitive(residual_param->mem_desc(),
residual_data,
"@user_residual_data_mem_p");
}
}
std::shared_ptr<dnnl::memory> AcquireDstMemoryWithResidual(
phi::DenseTensor* output, const phi::DenseTensor* residual_param) {
std::shared_ptr<dnnl::memory> dst_memory_p;
if (residual_param->mem_desc() != this->fwd_pd_->dst_desc()) {
auto residual_memory_p = this->AcquireResidualMemory(residual_param);
dst_memory_p = this->template AcquireDstMemory<T_out>(output);
this->AcquireReorder(residual_memory_p, dst_memory_p);
} else {
// Changing ShareDataWith to TensorCopy results in performance drop
// on ResNet architectures
// (https://github.com/PaddlePaddle/Paddle/issues/22964)
output->ShareDataWith(*residual_param);
dst_memory_p = this->template AcquireDstMemory<T_out>(output);
}
return dst_memory_p;
}
};
} // namespace onednn
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
#include "paddle/phi/kernels/onednn/conv_handler.h"
namespace phi {
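// Select the destination data type for the fused conv primitive: for int8
// kernels, relu/relu6 fusion produces u8 output and anything else s8,
// force_fp32_output falls back to f32, and residual fusion adopts the
// residual tensor's dtype; non-int8 kernels stay f32 unless bf16 output is
// requested.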
static dnnl::memory::data_type GetDstType(
bool is_int8,
bool is_bfloat16,
bool force_fp32_output,
std::string fuse_activation,
bool fuse_residual_conn,
const phi::DenseTensor* residual_param) {
auto dst_dt = dnnl::memory::data_type::f32;
if (is_int8) {
dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
? dnnl::memory::data_type::u8
: dnnl::memory::data_type::s8;
if (force_fp32_output) {
dst_dt = dnnl::memory::data_type::f32;
}
if (fuse_residual_conn && residual_param) {
auto residual_dt = funcs::ToOneDNNDataType(residual_param->dtype());
if (dst_dt != residual_dt) dst_dt = residual_dt;
}
} else {
if (!force_fp32_output && is_bfloat16) {
dst_dt = dnnl::memory::data_type::bf16;
if (fuse_residual_conn && residual_param) {
dst_dt = funcs::ToOneDNNDataType(residual_param->dtype());
}
}
}
return dst_dt;
}
#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
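// PD_VISIT_FLOAT_AND_INT8_TYPES dispatches on the filter dtype (fp32 or int8
// weights) and defines `data_t` inside the visited lambda; the kernels below
// use it as the K (filter type) template argument of ConvOneDNNHandlerT.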
template <typename T, typename T_out>
void ComputeFP32(const OneDNNContext& dev_ctx,
const DenseTensor* input,
const DenseTensor* filter,
const DenseTensor* bias,
const DenseTensor* residual_param,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
bool is_test,
bool is_BFLOAT16,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
DenseTensor* output) {
const auto& onednn_engine = dev_ctx.GetEngine();
const bool is_conv3d = strides.size() == 3U;
const std::string& unique_name =
dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0];
PD_VISIT_FLOAT_AND_INT8_TYPES(
filter->dtype(), "ConvOneDNNHandlerT", ([&] {
onednn::ConvOneDNNHandlerT<T, data_t, T_out> handler(dev_ctx,
onednn_engine,
dev_ctx.GetPlace(),
input,
filter,
bias,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
output,
unique_name);
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, is_conv3d, is_test);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
dst_memory_p =
handler.AcquireDstMemoryWithResidual(output, residual_param);
} else {
dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
}
auto conv_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, is_test);
args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
auto& astream = OneDNNContext::tls().get_stream();
conv_p->execute(astream, args);
astream.wait();
output->set_mem_desc(dst_memory_p->get_desc());
}));
}
template <typename T, typename T_out>
void ComputeINT8(const OneDNNContext& dev_ctx,
const DenseTensor* input,
const DenseTensor* filter,
const DenseTensor* bias,
const DenseTensor* residual_param,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
bool is_test,
bool is_BFLOAT16,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
DenseTensor* output) {
const auto& onednn_engine = dev_ctx.GetEngine();
const bool is_conv3d = strides.size() == 3U;
bool unsigned_output =
(fuse_activation == "relu" || fuse_activation == "relu6");
bool need_s8_to_u8 = false;
PADDLE_ENFORCE_NE(
is_conv3d,
true,
phi::errors::Unimplemented(
"OneDNN int8 convolution does not support 3D inputs currently"));
PADDLE_ENFORCE_EQ(
fuse_residual_conn && force_fp32_output,
false,
phi::errors::Unimplemented(
"residual fusion does not support force output with fp32"));
const std::string& unique_name =
dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0];
PD_VISIT_FLOAT_AND_INT8_TYPES(
filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
onednn::ConvOneDNNHandlerT<T, data_t, T_out> handler(dev_ctx,
onednn_engine,
dev_ctx.GetPlace(),
input,
filter,
bias,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
output,
unique_name);
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
const auto& scale_weights_data =
dev_ctx.HasDnnAttr("Scale_weights")
? PADDLE_GET_CONST(std::vector<float>,
dev_ctx.GetDnnAttr("Scale_weights"))
: std::vector<float>{1.0f};
const bool is_multi_channel = scale_weights_data.size() > 1;
int mask_reorder = is_multi_channel
? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
: 0;
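          // oneDNN reorder scale mask: each set bit marks a weights dim that
          // carries its own scale -- dim 0 for per-output-channel scales,
          // plus dim 1 when grouped weights prepend a groups dimension.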
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, false, true, scale_weights_data, mask_reorder);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
PADDLE_ENFORCE_EQ(
output->dims(),
residual_param->dims(),
phi::errors::InvalidArgument(
"Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d"
" and residual param's dimension =%d .",
output->dims().size(),
residual_param->dims().size()));
dst_memory_p =
handler.AcquireDstMemoryWithResidual(output, residual_param);
need_s8_to_u8 = (funcs::OneDNNGetDataType<T_out>() ==
dnnl::memory::data_type::s8) &&
unsigned_output;
} else {
dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
}
auto conv_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
std::vector<float> bias_scales;
auto p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
if (dev_ctx.HasDnnAttr("Bias_scales")) {
bias_scales = PADDLE_GET_CONST(std::vector<float>,
dev_ctx.GetDnnAttr("Bias_scales"));
p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder),
bias_scales));
} else {
p_scales_tuple = handler.get_int8_bias_scales(
filter, groups, scale_weights_data);
}
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias,
true,
std::get<1>(*p_scales_tuple),
std::get<0>(*p_scales_tuple));
args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
auto& astream = OneDNNContext::tls().get_stream();
conv_p->execute(astream, args);
astream.wait();
if (need_s8_to_u8) {
dev_ctx.Alloc<uint8_t>(output);
}
output->set_mem_desc(dst_memory_p->get_desc());
}));
}
template <typename T, typename Context>
void ConvKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
AllocationType::CPU,
phi::errors::PreconditionNotMet("Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
bool is_test = dev_ctx.HasDnnAttr("is_test")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
: false;
bool is_BFLOAT16 =
dev_ctx.HasDnnAttr("mkldnn_data_type")
? PADDLE_GET_CONST(std::string,
dev_ctx.GetDnnAttr("mkldnn_data_type")) ==
"bfloat16"
: false;
const auto* bias =
dev_ctx.HasDnnInput("Bias") ? dev_ctx.GetDnnInput("Bias") : nullptr;
const auto* residual_param = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;
bool fuse_residual_conn =
dev_ctx.HasDnnAttr("fuse_residual_connection")
? PADDLE_GET_CONST(bool,
dev_ctx.GetDnnAttr("fuse_residual_connection"))
: false;
const std::string& fuse_activation =
dev_ctx.HasDnnAttr("fuse_activation")
? PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"))
: "";
bool force_fp32_output =
dev_ctx.HasDnnAttr("force_fp32_output")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false;
auto dst_dt = GetDstType(is_INT8,
is_BFLOAT16,
force_fp32_output,
fuse_activation,
fuse_residual_conn,
residual_param);
if (!is_INT8) {
if (dst_dt == dnnl::memory::data_type::f32) {
ComputeFP32<T, float>(dev_ctx,
&input,
&filter,
bias,
residual_param,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
out);
} else if (dst_dt == dnnl::memory::data_type::bf16) {
ComputeFP32<T, dtype::bfloat16>(dev_ctx,
&input,
&filter,
bias,
residual_param,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
out);
}
} else {
if (dst_dt == dnnl::memory::data_type::f32) {
ComputeINT8<T, float>(dev_ctx,
&input,
&filter,
bias,
residual_param,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
out);
} else if (dst_dt == dnnl::memory::data_type::u8) {
ComputeINT8<T, uint8_t>(dev_ctx,
&input,
&filter,
bias,
residual_param,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
out);
} else if (dst_dt == dnnl::memory::data_type::s8) {
ComputeINT8<T, int8_t>(dev_ctx,
&input,
&filter,
bias,
residual_param,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
is_test,
is_BFLOAT16,
fuse_activation,
fuse_residual_conn,
force_fp32_output,
out);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(conv2d,
OneDNN,
ONEDNN,
phi::ConvKernel,
float,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
......@@ -28,12 +28,9 @@ void ConvGradKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* input_grad,
DenseTensor* filter_grad) {
using XPUT = typename XPUTypeTrait<T>::Type;
......@@ -151,7 +148,7 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -167,13 +164,10 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
out_grad,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
input_grad,
filter_grad);
}
......
......@@ -27,12 +27,9 @@ void ConvKernel(const Context& dev_ctx,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
int groups,
const std::string& data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search,
DenseTensor* out) {
using XPUT = typename XPUTypeTrait<T>::Type;
std::vector<int> paddings = paddings_t;
......@@ -117,7 +114,7 @@ void DepthwiseConvKernel(const Context& dev_ctx,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
......@@ -131,13 +128,10 @@ void DepthwiseConvKernel(const Context& dev_ctx,
filter,
strides,
paddings,
paddding_algorithm,
groups,
padding_algorithm,
dilations,
groups,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search,
out);
}
......
......@@ -17,31 +17,15 @@
namespace phi {
KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (!ctx.HasAttr("use_addto") || !ctx.HasAttr("workspace_size_MB") ||
!ctx.HasAttr("exhaustive_search")) {
return KernelSignature("conv2d_infer",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
} else {
return KernelSignature("conv2d",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
"groups",
"data_format"},
{"Output"});
}
}
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......@@ -50,12 +34,9 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
"groups",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
......@@ -66,12 +47,9 @@ KernelSignature Conv2dDoubleGradOpArgumentMapping(
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
"groups",
"data_format"},
{"DInput", "DFilter", "DDOutput"});
}
......
......@@ -289,12 +289,9 @@ class Conv2D(layers.Layer):
self._stride,
self._padding,
"EXPLICIT",
self._groups if self._groups else 1,
self._dilation,
self._groups if self._groups else 1,
"NCHW",
False,
-1,
False,
)
if self.bias is not None:
pre_act = F.elementwise_add(pre_bias, self.bias, axis=1)
......
......@@ -137,12 +137,9 @@ def _conv_nd(
stride,
padding,
padding_algorithm,
groups,
dilation,
groups,
data_format,
False,
-1,
False,
)
if bias is not None:
channel_dim = (
......@@ -486,6 +483,18 @@ def conv1d(
x = unsqueeze(x, axis=[squeeze_aixs])
if in_dygraph_mode():
if l_type == 'conv2d':
out = _C_ops.conv2d(
x,
weight,
stride,
padding,
padding_algorithm,
dilation,
groups,
conv2d_data_format,
)
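# Note: the trimmed conv2d signature drops use_addto, workspace_size_MB and
# exhaustive_search (now handled as device-specific extra attributes) and
# passes dilation before groups.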
else:
out = getattr(_C_ops, l_type)(
x,
weight,
......@@ -746,12 +755,9 @@ def conv2d(
stride,
padding,
padding_algorithm,
groups,
dilation,
groups,
data_format,
False,
-1,
False,
)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
......