Unverified · Commit 905b0765 authored by xiaoting, committed by GitHub

rm max_input in conv2d for kunlun, test=kunlun (#28063)

Parent 8600f474
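This commit removes the explicit max-value side outputs (MaxInput, MaxFilter, MaxOutputGrad) from the Kunlun (XPU) conv2d kernels. Instead of calling xpu::findmax up front and feeding the resulting buffers into the int16 conv primitives, the kernels now pass nullptr for the trailing max arguments. Below is a minimal before/after sketch of the forward call; the argument names are taken from the diff, and the claim that a null max pointer makes the XPU runtime derive the maxima internally is an assumption, not something the commit states.

// Before (sketch): maxima computed explicitly and handed to the kernel.
xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
             max_input->data<T>());
xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
             max_filter->data<T>());
int r = xpu::conv2d_forward_int16<float, float, float, float>(
    dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
    strides[0], strides[1], paddings[0], paddings[1], dilations[0],
    dilations[1], groups, input->data<float>(), filter.data<float>(),
    output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
    max_input->data<float>(), max_filter->data<float>());

// After (sketch): no findmax; the trailing max pointers are null, which is
// assumed to let the XPU runtime compute the maxima on the fly.
r = xpu::conv2d_forward_int16<float, float, float, float>(
    dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
    strides[0], strides[1], paddings[0], paddings[1], dilations[0],
    dilations[1], groups, input->data<float>(), filter.data<float>(),
    output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
    nullptr, nullptr);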
@@ -27,10 +27,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* output = context.Output<Tensor>("Output");
-    Tensor* max_input = context.Output<Tensor>("MaxInput");
-    Tensor* max_filter = context.Output<Tensor>("MaxFilter");
-    max_input->mutable_data<T>(context.GetPlace());
-    max_filter->mutable_data<T>(context.GetPlace());
+    // Tensor* max_input = context.Output<Tensor>("MaxInput");
+    // Tensor* max_filter = context.Output<Tensor>("MaxFilter");
+    // max_input->mutable_data<T>(context.GetPlace());
+    // max_filter->mutable_data<T>(context.GetPlace());
     output->mutable_data<T>(context.GetPlace());
     int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@@ -47,28 +47,28 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
         dilations[0] == 1 && dilations[1] == 1, true,
         platform::errors::InvalidArgument("XPU only support dilation == 1."));
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    PADDLE_ENFORCE_EQ(
-        xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
-                     max_input->data<T>()) == xpu::Error_t::SUCCESS,
-        true, platform::errors::InvalidArgument(
-                  "XPU conv kernel error,can not finde max_input,please "
-                  "check whether Baidu Kunlun "
-                  "Card is properly installed."));
-    PADDLE_ENFORCE_EQ(
-        xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
-                     max_filter->data<T>()) == xpu::Error_t::SUCCESS,
-        true, platform::errors::InvalidArgument(
-                  "XPU conv kernel error,can not find max_filter,please "
-                  "check whether Baidu Kunlun "
-                  "Card is properly installed."));
+    // PADDLE_ENFORCE_EQ(
+    //     xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
+    //                  max_input->data<T>()) == xpu::Error_t::SUCCESS,
+    //     true, platform::errors::InvalidArgument(
+    //               "XPU conv kernel error,can not finde max_input,please "
+    //               "check whether Baidu Kunlun "
+    //               "Card is properly installed."));
+    // PADDLE_ENFORCE_EQ(
+    //     xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
+    //                  max_filter->data<T>()) == xpu::Error_t::SUCCESS,
+    //     true, platform::errors::InvalidArgument(
+    //               "XPU conv kernel error,can not find max_filter,please "
+    //               "check whether Baidu Kunlun "
+    //               "Card is properly installed."));
     if (groups == 1) {
       int r = xpu::conv2d_forward_int16<float, float, float, float>(
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
           strides[0], strides[1], paddings[0], paddings[1], dilations[0],
           dilations[1], groups, input->data<float>(), filter.data<float>(),
           output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
-          // nullptr, nullptr);
-          max_input->data<float>(), max_filter->data<float>());
+          nullptr, nullptr);
+          // max_input->data<float>(), max_filter->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel return wrong value[%d], "
@@ -80,8 +80,8 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
           dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
           output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
           win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
-          // nullptr, nullptr);
-          max_input->data<float>(), max_filter->data<float>());
+          nullptr, nullptr);
+          // max_input->data<float>(), max_filter->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel return wrong value[%d], "
@@ -96,9 +96,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* max_input = context.Input<Tensor>("MaxInput");
-    const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
-    Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
+    // const Tensor* max_input = context.Input<Tensor>("MaxInput");
+    // const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
+    // Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
     const Tensor* output_grad =
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
@@ -133,25 +133,25 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
       filter_grad->mutable_data<T>(context.GetPlace());
     }
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    max_output_grad->Resize({4});
-    max_output_grad->mutable_data<T>(context.GetPlace());
-    PADDLE_ENFORCE_EQ(
-        xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
-                     output_grad->numel(),
-                     max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
-        true,
-        platform::errors::External(
-            "XPU conv kernel error, can not find max_output_grad, please check "
-            "whether Baidu Kunlun Card is "
-            "properly installed."));
+    // max_output_grad->Resize({4});
+    // max_output_grad->mutable_data<T>(context.GetPlace());
+    // PADDLE_ENFORCE_EQ(
+    //     xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
+    //                  output_grad->numel(),
+    //                  max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
+    //     true,
+    //     platform::errors::External(
+    //         "XPU conv kernel error, can not find max_output_grad, please
+    //         check "
+    //         "whether Baidu Kunlun Card is "
+    //         "properly installed."));
     if (input_grad) {
       int r = xpu::conv2d_backward_int16(
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
           strides[0], strides[1], paddings[0], paddings[1], dilations[0],
           dilations[1], groups, output_grad->data<float>(),
-          filter.data<float>(), input_grad->data<float>(),
-          // nullptr, nullptr,
-          max_output_grad->data<float>(), max_filter->data<float>());
+          filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
+          // max_output_grad->data<float>(), max_filter->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel return wrong value[%d], "
@@ -164,9 +164,8 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
           strides[0], strides[1], paddings[0], paddings[1], dilations[0],
           dilations[1], groups, output_grad->data<float>(),
-          input->data<float>(), filter_grad->data<float>(),
-          // nullptr, nullptr,
-          max_output_grad->data<float>(), max_input->data<float>());
+          input->data<float>(), filter_grad->data<float>(), nullptr, nullptr);
+          // max_output_grad->data<float>(), max_input->data<float>());
       PADDLE_ENFORCE_EQ(
           r, XPU_SUCCESS,
           platform::errors::External("XPU conv kernel return wrong value[%d], "
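The gradient kernel follows the same pattern: the MaxOutputGrad buffer, its Resize/mutable_data setup, and the findmax call on output_grad are all dropped, and both backward calls now end with nullptr, nullptr. A condensed sketch of the input-gradient branch after the change follows; the names come from the diff, and the tail of the error message (truncated in the hunk above) is left elided.

// Input-gradient branch after this commit (condensed sketch).
// No max tensors remain, so the call ends with two null max pointers,
// followed by the usual status check on the XPU return code.
if (input_grad) {
  int r = xpu::conv2d_backward_int16(
      dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
      strides[0], strides[1], paddings[0], paddings[1], dilations[0],
      dilations[1], groups, output_grad->data<float>(),
      filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
  PADDLE_ENFORCE_EQ(
      r, XPU_SUCCESS,
      platform::errors::External("XPU conv kernel return wrong value[%d], "
                                 /* rest of message elided in the diff */,
                                 r));
}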