未验证 提交 66a10f36 编写于 作者: J jakpiase 提交者: GitHub

[Video detection] Added fill_constant FP32 FWD oneDNN kernel (#37216)

* added fill_constant kernel

* CI fix

* ci fix

* switched from nan to zero memory

* CI FIX

* ci fixes

* CI rerun

* ci fix

* minor change

* CI rerun
上级 4892d592
...@@ -58,4 +58,8 @@ extra { ...@@ -58,4 +58,8 @@ extra {
name: "op_device" name: "op_device"
type: STRING type: STRING
} }
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
} }
...@@ -22,10 +22,10 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -22,10 +22,10 @@ class FillConstantOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FillConstant"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FillConstant");
auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape"); auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) { if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
...@@ -52,8 +52,8 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -52,8 +52,8 @@ class FillConstantOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor, const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType& expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
return expected_kernel_type; return expected_kernel_type;
} else { } else {
...@@ -63,7 +63,7 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -63,7 +63,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
...@@ -97,13 +97,24 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -97,13 +97,24 @@ class FillConstantOp : public framework::OperatorWithKernel {
} }
} }
#ifdef PADDLE_WITH_MKLDNN
auto input_data_type =
framework::proto::VarType::Type(ctx.Attr<int>("dtype"));
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return kt; return kt;
} }
}; };
class FillConstantOpVarTypeInference : public framework::VarTypeInference { class FillConstantOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>( auto data_type = static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, ctx->GetAttr("dtype"))); BOOST_GET_CONST(int, ctx->GetAttr("dtype")));
ctx->SetOutputDataType("Out", data_type); ctx->SetOutputDataType("Out", data_type);
...@@ -156,6 +167,10 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -156,6 +167,10 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"3: XPUPlace. " "3: XPUPlace. "
"4: NPUPlace. ") "4: NPUPlace. ")
.SetDefault(-1); .SetDefault(-1);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddOutput("Out", AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled " "(Tensor) Tensor of specified shape will be filled "
"with the specified value"); "with the specified value");
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
class FillConstantMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
FillConstantMKLDNNHandler(Tensor* out, dnnl::engine engine,
platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_md = dnnl::memory::desc(
{out->numel(), sizeof(T)}, platform::MKLDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attrs;
attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f});
this->AcquireForwardPrimitiveDescriptor(attrs, dnnl::algorithm::binary_add,
src0_md, src1_md, src0_md);
}
static const dnnl::memory::desc src1_md;
};
template <typename T>
const dnnl::memory::desc FillConstantMKLDNNHandler<T>::src1_md(
{1, sizeof(T)}, platform::MKLDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab);
template <typename T>
class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine();
auto* out = ctx.Output<Tensor>("Out");
T fill_value = CalculateFillValue(ctx);
auto shape = GetShape(ctx);
out->Resize(shape);
FillConstantMKLDNNHandler<T> handler(out, dnnl_engine, ctx.GetPlace());
dnnl::memory constant_value_memory =
dnnl::memory(FillConstantMKLDNNHandler<T>::src1_md, dnnl_engine,
reinterpret_cast<uint8_t*>(&fill_value));
auto src0_memory_p = handler.AcquireDstMemory(out);
auto fill_constant_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
fill_constant_p->execute(astream, {{DNNL_ARG_SRC_0, *src0_memory_p},
{DNNL_ARG_SRC_1, constant_value_memory},
{DNNL_ARG_DST, *src0_memory_p}});
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size()));
}
T CalculateFillValue(const framework::ExecutionContext& ctx) const {
const auto str_value = ctx.Attr<std::string>("str_value");
const auto float_value = ctx.Attr<float>("value");
T value;
if (str_value.empty()) {
value = static_cast<T>(float_value);
} else {
// handle NaN/Inf first, which cannot be read from stream
if (str_value == "inf") {
value = static_cast<T>(std::numeric_limits<float>::infinity());
} else if (str_value == "-inf") {
value = static_cast<T>(-std::numeric_limits<float>::infinity());
} else if (str_value == "nan") {
value = static_cast<T>(std::numeric_limits<float>::quiet_NaN());
} else {
std::stringstream convert_stream(str_value);
double tmp_value;
convert_stream >> tmp_value;
value = static_cast<T>(tmp_value);
}
}
if (ctx.HasInput("ValueTensor")) {
const auto* value_tensor = ctx.Input<Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ(
value_tensor->numel(), 1,
platform::errors::InvalidArgument(
"When use Tensor as value to set Tensor value in fill_constant, "
"value input(ValueTensor) size must be 1, but got %d",
value_tensor->numel()));
value = value_tensor->data<T>()[0];
}
return value;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(fill_constant, MKLDNN, paddle::platform::CPUPlace,
ops::FillConstantMKLDNNKernel<float>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool
import paddle
@OpTestTool.skip_if_not_cpu_bf16()
class TestFillConstant2DOneDNNOp(OpTest):
def setUp(self):
self.op_type = "fill_constant"
self.dtype = np.float32
self.shape_tensor_list = None
self.shape_tensor = None
self.str_value = ""
real_shape = []
self.value = 0.1
self.set_inputs()
self.set_attrs()
if 'value' in self.attrs:
self.value = self.attrs['value']
if self.str_value != "":
self.value = float(self.str_value)
if 'ValueTensor' in self.inputs:
self.value = self.inputs['ValueTensor']
if 'shape' in self.attrs:
real_shape = self.attrs['shape']
if 'ShapeTensor' in self.inputs:
real_shape = list(self.inputs['ShapeTensor'])
if 'ShapeTensorList' in self.inputs:
real_shape = []
for shape_tensor in self.inputs['ShapeTensorList']:
real_shape.append(shape_tensor[1].item())
self.outputs = {'Out': np.full(real_shape, self.value)}
def set_inputs(self):
self.inputs = {}
def set_attrs(self):
self.attrs = {'shape': (3, 5), 'use_mkldnn': True, 'value': self.value}
def test_check_output(self):
self.check_output()
class TestFillZerosLike4DShapeTensorPriorityOneDNNOp(
TestFillConstant2DOneDNNOp):
def set_inputs(self):
self.inputs = {'ShapeTensor': np.array([5, 6, 7, 8]).astype("int32")}
class TestFillZerosLike4DShapeTensorListPriorityOneDNNOp(
TestFillConstant2DOneDNNOp):
def set_inputs(self):
shape = (4, 5, 6, 7)
self.shape_tensor_list = []
for index, elem in enumerate(shape):
self.shape_tensor_list.append(("x" + str(index), np.ones(
(1)).astype('int32') * elem))
self.inputs = {'ShapeTensorList': self.shape_tensor_list}
class TestFillZerosLike2DStringValueInfOneDNNOp(TestFillConstant2DOneDNNOp):
def set_attrs(self):
self.str_value = "inf"
self.attrs = {'shape': (10, 13), 'use_mkldnn': True, 'str_value': "inf"}
class TestFillZerosLike2DStringValueMinusInfOneDNNOp(
TestFillConstant2DOneDNNOp):
def set_attrs(self):
self.str_value = "-inf"
self.attrs = {
'shape': (10, 13),
'use_mkldnn': True,
'str_value': "-inf"
}
class TestFillZerosLike2DStringValueFloatOneDNNOp(TestFillConstant2DOneDNNOp):
def set_attrs(self):
self.str_value = "0.123"
self.attrs = {
'shape': (10, 13),
'use_mkldnn': True,
'str_value': "0.123"
}
class TestFillZerosLike2DValueTensorPriorityOneDNNOp(
TestFillZerosLike2DStringValueFloatOneDNNOp):
def set_inputs(self):
self.inputs = {'ValueTensor': np.atleast_1d(2.25).astype("float32")}
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册