未验证 提交 21448525 编写于 作者: J jakpiase 提交者: GitHub

[CHERRY-PICK] Reduce grad fix cherrypick (#32742)

* base changes for fix

* minor change

* fix for bwd kernel

* removed unnecessary import

* implemented reviewers suggestions

* CI fix
上级 9a589de8
......@@ -45,7 +45,8 @@ class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
number_of_elements = input_x->numel();
}
this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f,
this->RunKernel(ctx, dnnl::algorithm::binary_add,
dnnl::algorithm::reduction_mean, 0.0f,
1.0L / number_of_elements);
}
};
......
......@@ -21,6 +21,27 @@ using paddle::framework::LoDTensor;
using paddle::framework::Tensor;
using platform::to_void_cast;
inline std::vector<int64_t> CalculateReducedDims(const Tensor* input,
const Tensor* output,
std::vector<int>& reduce_dims,
bool reduce_all,
bool keep_dim) {
if (keep_dim) return framework::vectorize(output->dims());
if (reduce_all)
return std::vector<int64_t>(framework::vectorize(input->dims()).size(), 1);
std::vector<int64_t> output_dims(framework::vectorize(input->dims()));
for (size_t i = 0; i < reduce_dims.size(); ++i) {
reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i]
: input->dims().size() + reduce_dims[i];
output_dims[reduce_dims[i]] = 1;
}
return output_dims;
}
template <typename T>
class ReduceMKLDNNKernel : public framework::OpKernel<T> {
public:
......@@ -37,9 +58,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
bool reduce_all = ctx.Attr<bool>("reduce_all");
bool keep_dim = ctx.Attr<bool>("keep_dim");
std::vector<int64_t> output_dims =
CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim);
auto output_dims =
CalculateReducedDims(input, output, reduce_dims, reduce_all, keep_dim);
auto input_dims = framework::vectorize(input->dims());
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
......@@ -96,53 +116,63 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
paddle::framework::vectorize<int64_t>(output->dims()))));
}
}
private:
std::vector<int64_t> CalculateOutputDims(const Tensor* input,
const Tensor* output,
std::vector<int>& reduce_dims,
bool reduce_all,
bool keep_dim) const {
if (keep_dim) return framework::vectorize(output->dims());
if (reduce_all)
return std::vector<int64_t>(framework::vectorize(input->dims()).size(),
1);
std::vector<int64_t> output_dims(framework::vectorize(input->dims()));
for (size_t i = 0; i < reduce_dims.size(); ++i) {
reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i]
: input->dims().size() + reduce_dims[i];
output_dims[reduce_dims[i]] = 1;
}
return output_dims;
}
};
template <typename T>
class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
public:
void RunKernel(const framework::ExecutionContext& ctx,
dnnl::algorithm binary_type, float scale_x,
float scale_y) const {
dnnl::algorithm binary_type, dnnl::algorithm reduction_type,
float scale_x, float scale_y) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
bool keep_dim = ctx.Attr<bool>("keep_dim");
bool reduce_all = ctx.Attr<bool>("reduce_all");
auto dims = ctx.Attr<std::vector<int>>("dim");
auto* input_dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
mkldnn::memory::format_tag x_format_tag;
auto input_dims =
CalculateReducedDims(output_dx, input_dy, dims, reduce_all, keep_dim);
if (input_dims != framework::vectorize(output_dx->dims())) {
const std::string key_pd =
platform::CreateKey(
dev_ctx, framework::vectorize(output_dx->dims()),
ctx.InputName("X"),
(std::to_string(static_cast<int>(reduction_type)))) +
"@fwd_pd";
std::shared_ptr<dnnl::reduction::primitive_desc> fwd_pd =
std::static_pointer_cast<dnnl::reduction::primitive_desc>(
dev_ctx.GetBlob(key_pd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd, platform::errors::Unavailable(
"Forward primitive descriptor is not available in %s op, "
"cannot deduce memory format tag",
ctx.Type()));
x_format_tag = platform::GetMKLDNNFormat(fwd_pd->src_desc());
PADDLE_ENFORCE_NE(x_format_tag, mkldnn::memory::format_tag::undef,
platform::errors::InvalidArgument(
"Cannot deduce format tag for %s op", ctx.Type()));
} else { // fwd descriptor not available because reorder was used instead
// of reduction
x_format_tag = getPlainFormatTag(output_dx);
}
output_dx->mutable_data<T>(ctx.GetPlace());
output_dx->set_format(getPlainFormatTag(output_dx));
output_dx->set_format(x_format_tag);
output_dx->set_layout(input_dy->layout());
platform::BroadcastDataMKLDNNHandler<T> handler(
binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
input_dy, scale_x, scale_y,
ctx.InputName(framework::GradVarName("Out")));
ctx.InputName(framework::GradVarName("Out")), input_dims);
const auto src_dx_memory = handler.AcquireSrcMemory(output_dx);
const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy);
......
......@@ -29,7 +29,8 @@ template <typename T>
class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0f);
this->RunKernel(ctx, dnnl::algorithm::binary_add,
dnnl::algorithm::reduction_sum, 0.0f, 1.0f);
}
};
......
......@@ -559,8 +559,11 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
int in_dtype = ctx.Attr<int>("in_dtype");
auto input_data_type =
(in_dtype >= 0) ? static_cast<framework::proto::VarType::Type>(in_dtype)
: OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
auto CanMKLDNNReduceGradBeUsed = [&]() {
......@@ -568,18 +571,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
if (dx_dims.size() > 5) return false; // max 5D tensor is supported
if (ctx.Attr<bool>("reduce_all") ||
((int)ctx.Attr<std::vector<int>>("dim").size() == dx_dims.size()))
return true;
auto dy_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
// Subtensor must be on rightmost part of the bigger tensor
for (int i = 0; i < dy_dims.size(); ++i) {
if (dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]) {
return false;
}
}
return true;
};
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
......@@ -590,12 +581,6 @@ class ReduceGradOp : public framework::OperatorWithKernel {
}
#endif
int in_dtype = ctx.Attr<int>("in_dtype");
if (in_dtype >= 0) {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(in_dtype),
ctx.GetPlace());
}
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -639,7 +639,8 @@ class BroadcastDataMKLDNNHandler
const mkldnn::engine engine,
platform::Place cpu_place, const Tensor* x,
const Tensor* y, float scale_x, float scale_y,
const std::string& uniq_name)
const std::string& uniq_name,
std::vector<int64_t>& input_dims)
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
......@@ -659,24 +660,12 @@ class BroadcastDataMKLDNNHandler
y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Y tensor."));
auto src1_tz = framework::vectorize(y->dims());
const auto src0_tz = framework::vectorize(x->dims());
// GetExpectedKernelType checks if smaller vector is a subvector with all
// the dims in correct order on the rightmost part of the bigger vector,
// i.e. a correct vector for broadcasting:
// x = 5, 7, 3, 2, 4, 8
// y = 4, 8
src1_tz.reserve(src0_tz.size());
for (size_t i = src1_tz.size(); i < src0_tz.size(); ++i) {
src1_tz.insert(src1_tz.begin(), 1L);
}
const auto src0_md = dnnl::memory::desc(
src0_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto src1_md = dnnl::memory::desc(
src1_tz, platform::MKLDNNGetDataType<T>(), x->format());
input_dims, platform::MKLDNNGetDataType<T>(), x->format());
dnnl::primitive_attr attributes;
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
......@@ -711,7 +700,7 @@ class ReductionMKLDNNHandler
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y,
const std::string& uniq_name,
std::vector<int64_t> output_dims)
std::vector<int64_t> y_tz)
: platform::MKLDNNHandlerT<T, dnnl::reduction>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
......@@ -725,14 +714,14 @@ class ReductionMKLDNNHandler
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor."));
const auto src_tz = framework::vectorize(x->dims());
const auto x_tz = framework::vectorize(x->dims());
const auto src_md = dnnl::memory::desc(
src_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto dst_md = memory::desc(
output_dims, platform::MKLDNNGetDataType<T>(), x->format());
const auto x_md = dnnl::memory::desc(
x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto y_md =
memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps);
this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册