未验证 提交 0a5c99e8 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to issue #34554 (#34623)

* - Added softmax without caching

* - Binary is no longer manually cached

* - Activation onednn caching removed

* - Removed manual caching of activation

* - modified UT

* - fix

* - fix

* - fixes to building

* - fix

* - fix

* - fix to UT

* - Faulty UT workaround

* - approval workaround

* - Fixes after review

* - compilation fixes

* - more lint fixes

* - more fixes after review

* - fixes after another round of review
上级 99f8f5c8
...@@ -47,13 +47,24 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -47,13 +47,24 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float scale_o = ctx.Attr<float>("Scale_out"); float scale_o = ctx.Attr<float>("Scale_out");
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler(BINARY_OP, axis, mkldnn_engine,
BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, ctx.GetPlace(), x, y, z, scale_x,
scale_x, scale_y, scale_o, ctx.OutputName("Out")); scale_y, scale_o);
const auto src_x_memory = handler.AcquireSrcMemory(x); const auto src_x_memory = handler.AcquireSrcMemory(x);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y); const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_memory = handler.AcquireDstMemory(z); // (jczaja) For Inplace src and dst should be the same memory object.
// So x should share buffer with z. But UT mechanics is testing inplace
// execution for this op not checking that x can be bradcasted to match in
// shape y tensor.
// This is wrong as when x is to be broadcasted then z(out) will match the
// shape of y which is bigger than x. Hence if x is smaller in shape than z
// and they share a buffer (of
// shape x) then this buffer is not big enough to hold result of elementwise
// operation.
auto dst_memory = (x->numel() == z->numel() && x->IsSharedBufferWith(*z))
? src_x_memory
: handler.AcquireDstMemory(z);
const auto binary_prim = handler.AcquireForwardPrimitive(); const auto binary_prim = handler.AcquireForwardPrimitive();
......
...@@ -48,9 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -48,9 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
if (dx) { if (dx) {
// dx = dout*y // dx = dout*y
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine, dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f, dout, y, dx, 1.0f, 1.0f, 1.0f);
ctx.InputName(framework::GradVarName("Out")));
const auto src_dout_memory = handler.AcquireSrcMemory(dout); const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y); const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
...@@ -75,9 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -75,9 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Handler is having nullptr passed instead of output tensor as // Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor // we want Dst buffer to be allocated by oneDNN not to use Tensor
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine, dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f, dout, x, nullptr, 1.0f, 1.0f, 1.0f);
ctx.InputName(framework::GradVarName("Out")));
const auto src_dout_memory = handler.AcquireSrcMemory(dout); const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x); const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
......
...@@ -79,15 +79,15 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -79,15 +79,15 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL eletwise_forward must use CPUPlace")); "Operator DNNL eletwise_forward must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Out"); auto *y = ctx.Output<Tensor>("Out");
bool is_inplaced = x->IsSharedBufferWith(*y); bool is_inplaced = x->IsSharedBufferWith(*y);
platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx, platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
ctx.GetPlace(), x, ctx.GetPlace(), x);
ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y); auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
...@@ -106,13 +106,14 @@ template <typename T> ...@@ -106,13 +106,14 @@ template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx, void eltwise_grad(const framework::ExecutionContext &ctx,
mkldnn::algorithm algorithm) { mkldnn::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler( platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X")); ctx.GetPlace(), x, diff_y);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x); auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
......
cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op softmax scope device_context enforce) cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op conv_op im2col vol2col softmax scope device_context enforce)
...@@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> { ...@@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
...@@ -36,11 +37,12 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> { ...@@ -36,11 +37,12 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
bool is_inplaced = x->IsSharedBufferWith(*out); bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler( platform::ActivationMKLDNNHandler<T> handler(
mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x, mkldnn::algorithm::eltwise_linear, ctx, mkldnn_engine, ctx.GetPlace(),
ctx.InputName("X"), is_inplaced); x);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out); auto dst_memory_p =
is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive(); auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
...@@ -32,69 +32,56 @@ using platform::to_void_cast; ...@@ -32,69 +32,56 @@ using platform::to_void_cast;
template <typename T> template <typename T>
class SoftmaxMKLDNNHandler class SoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : public platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> { mkldnn::softmax_backward> {
public: public:
SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, SoftmaxMKLDNNHandler(const mkldnn::engine mkldnn_engine,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input, platform::Place cpu_place, const Tensor* input,
Tensor* output, const int axis, Tensor* output, const int axis)
const std::string uniq_name, bool is_inplaced) : platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, mkldnn::softmax_backward>(
mkldnn::softmax_backward>( mkldnn_engine, cpu_place) {
dev_ctx, mkldnn_engine, cpu_place, PADDLE_ENFORCE_EQ(
// Softmax may be inplace then uniq_name is no longer unique input->dims(), output->dims(),
is_inplaced ? platform::CreateKey( platform::errors::InvalidArgument(
dev_ctx, framework::vectorize(input->dims()), "The shape of input and output tensor must be identical."));
axis, uniq_name)
: platform::CreateKey( auto softmax_tz = framework::vectorize(input->dims());
dev_ctx, framework::vectorize(input->dims()), auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
uniq_name)) { input->format());
if (!this->isCached()) {
PADDLE_ENFORCE_EQ( this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
input->dims(), output->dims(), axis);
platform::errors::InvalidArgument(
"The shape of input and output tensor must be identical."));
auto softmax_tz = framework::vectorize(input->dims());
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
}
} }
SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx, SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* out, platform::Place cpu_place, const Tensor* out,
const Tensor* out_grad, Tensor* in_x_grad, const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name) const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, mkldnn_engine, cpu_place) {
platform::CreateKey(dev_ctx, framework::vectorize(out->dims()), PADDLE_ENFORCE_EQ(out_grad->dims(), in_x_grad->dims(),
unique_name)) { platform::errors::InvalidArgument(
if (!this->isBwdCached()) { "The shape of softmax_grad's input "
PADDLE_ENFORCE_EQ( "and output must be identical, but shapes differ, "
out_grad->dims(), in_x_grad->dims(), "out_grad: %s in_grad: %s",
platform::errors::InvalidArgument("The shape of softmax_grad's input " out_grad->dims(), in_x_grad->dims()));
"and output must be identical."));
auto dims = out_grad->dims(); // input and output share the same shape
auto dims = out_grad->dims(); // input and output share the same shape const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size()); auto softmax_tz = framework::vectorize<int64_t>(dims);
auto softmax_tz = framework::vectorize<int64_t>(dims);
auto data_softmax_md = MKLDNNMemDesc(
auto data_softmax_md = MKLDNNMemDesc( softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
softmax_tz, platform::MKLDNNGetDataType<T>(), out->format()); auto diff_softmax_md = MKLDNNMemDesc(
auto diff_softmax_md = MKLDNNMemDesc( softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, data_softmax_md, axis);
data_softmax_md, axis); this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, axis);
axis);
}
} }
}; };
...@@ -111,9 +98,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -111,9 +98,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size()); const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(), SoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), input,
input, output, axis, ctx.OutputName("Out"), output, axis);
is_inplaced);
auto softmax_src_memory_p = handler.AcquireSrcMemory(input); auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
// For Inplace src and and dst are the same memory object // For Inplace src and and dst are the same memory object
...@@ -149,11 +135,12 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -149,11 +135,12 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL SoftmaxGrad must use CPUPlace")); "Operator DNNL SoftmaxGrad must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* output = ctx.Input<Tensor>("Out"); const Tensor* output = ctx.Input<Tensor>("Out");
auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out")); auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X")); auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));
SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output, SoftmaxMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), output,
out_grad, in_x_grad, ctx.InputName("Out")); out_grad, in_x_grad, ctx.InputName("Out"));
auto dst_memory_p = handler.AcquireDstMemory(output); auto dst_memory_p = handler.AcquireDstMemory(output);
......
...@@ -33,6 +33,8 @@ USE_OP(relu); ...@@ -33,6 +33,8 @@ USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(softmax); USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(conv2d);
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -64,16 +66,19 @@ class CacheTester { ...@@ -64,16 +66,19 @@ class CacheTester {
template <typename T> template <typename T>
void RunOperator(const platform::Place &place, const std::string &op_type, void RunOperator(const platform::Place &place, const std::string &op_type,
const framework::DDim &dims, const std::string &output_name, const framework::DDim &dims, const std::string &first_input) {
bool inplace = false) {
framework::Scope scope; framework::Scope scope;
std::map<const std::string, int> num_inputs = {{"softmax", 1}, std::map<const std::string, int> num_inputs = {{"softmax", 1},
{"relu", 1}, {"relu", 1},
{"conv2d", 2},
{"elementwise_add", 2}, {"elementwise_add", 2},
{"elementwise_mul", 2}}; {"elementwise_mul", 2}};
std::string first_input = inplace == true ? output_name : "x"; std::string first_input_var_name = (op_type == "conv2d") ? "Input" : "X";
std::string second_input_var_name = (op_type == "conv2d") ? "Filter" : "Y";
std::string output_var_name = (op_type == "conv2d") ? "Output" : "Out";
std::string output_name = "output";
std::vector<InputVars> input_names = { std::vector<InputVars> input_names = {
{first_input, scope.Var(first_input)->GetMutable<framework::LoDTensor>()}, {first_input, scope.Var(first_input)->GetMutable<framework::LoDTensor>()},
...@@ -113,71 +118,40 @@ void RunOperator(const platform::Place &place, const std::string &op_type, ...@@ -113,71 +118,40 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
auto &pool = platform::DeviceContextPool::Instance(); auto &pool = platform::DeviceContextPool::Instance();
auto op = num_inputs[op_type] > 1 auto op =
? framework::OpRegistry::CreateOp( num_inputs[op_type] > 1
op_type, {{"X", {first_input}}, {"Y", {"x1"}}}, ? framework::OpRegistry::CreateOp(
{{"Out", {output_name}}}, {{"use_mkldnn", {true}}}) op_type, {{first_input_var_name, {first_input}},
: framework::OpRegistry::CreateOp( {second_input_var_name, {"x1"}}},
op_type, {{"X", {first_input}}}, {{"Out", {output_name}}}, {{output_var_name, {output_name}}}, {{"use_mkldnn", {true}}})
{{"use_mkldnn", {true}}}); : framework::OpRegistry::CreateOp(
op_type, {{first_input_var_name, {first_input}}},
{{output_var_name, {output_name}}}, {{"use_mkldnn", {true}}});
op->Run(scope, place); op->Run(scope, place);
pool.Get(place)->Wait(); pool.Get(place)->Wait();
} }
TEST(test_softmax_reuse_cache, cpu_place) { TEST(test_conv2d_reuse_cache, cpu_place) {
framework::DDim dims({32, 64}); framework::DDim dims({1, 16, 32, 64});
platform::CPUPlace p; platform::CPUPlace p;
CacheTester ct; CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out"); RunOperator<float>(p, "conv2d", dims, "input_signal");
RunOperator<float>(p, "softmax", dims, "softmax_out"); RunOperator<float>(p, "conv2d", dims, "input_signal");
PADDLE_ENFORCE_EQ(ct.Analyze(4), true, PADDLE_ENFORCE_EQ(ct.Analyze(9), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects")); "Invalid number of cached oneDNN objects"));
} }
TEST(test_softmax_noreuse_cache, cpu_place) { TEST(test_conv2d_noreuse_cache, cpu_place) {
framework::DDim dims({32, 64}); framework::DDim dims({1, 16, 32, 64});
platform::CPUPlace p; platform::CPUPlace p;
CacheTester ct; CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out"); RunOperator<float>(p, "conv2d", dims, "input_signal");
RunOperator<float>(p, "softmax", dims, "softmax_out2"); RunOperator<float>(p, "conv2d", dims, "input_signal2");
PADDLE_ENFORCE_EQ(ct.Analyze(8), true, PADDLE_ENFORCE_EQ(ct.Analyze(18), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects")); "Invalid number of cached oneDNN objects"));
}
TEST(test_softmax_inplace_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out");
RunOperator<float>(p, "softmax", dims, "softmax_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_relu_inplace_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "relu", dims, "relu_out");
RunOperator<float>(p, "relu", dims, "relu_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_elementwise_add_reuse_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "elementwise_add", dims, "elementwise_add_out");
RunOperator<float>(p, "relu", dims, "elementwise_add_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(8), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册