提交 68183f92 编写于 作者: X xiebaiyuan 提交者: GitHub

suite scale in nearest_interp_op ,test=mobile (#2633)

* suite scale in nearest_interp_op ,test=mobile

* suite scale in nearest_interp_op && fix grid_sampler_kernel.cl ,test=mobile

* suite scale in nearest_interp_op && fix grid_sampler_kernel.cl && fix code style ,test=mobile
上级 489a18b1
......@@ -28,8 +28,8 @@ __kernel void grid_sampler(__private const int out_height,
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x_grid = out_h / 4 * 2;
int y_grid = out_n * out_width + out_w;
float4 g1 = read_imagef(grid, sampler, int2(x_grid, y_grid));
float4 g2 = read_imagef(grid, sampler, int2(x_grid + 1, y_grid));
float4 g1 = read_imagef(grid, sampler, (int2)(x_grid, y_grid));
float4 g2 = read_imagef(grid, sampler, (int2)(x_grid + 1, y_grid));
float x = (g1.x + 1) * (out_width - 1) / 2;
float y = (g2.x + 1) * (out_height - 1) / 2;
......@@ -39,15 +39,15 @@ __kernel void grid_sampler(__private const int out_height,
int y_p = out_n * out_height + y0;
int x_out = out_c * out_width + out_w;
int y_out = out_n * out_height + out_h;
float4 input0 = read_imagef(input, sampler, int2(x_p, y_p));
float4 input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
float4 input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
float4 input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
float4 input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
float4 input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
float4 input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
float4 input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
float4 out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out), convert_half4(out_val));
write_imageh(output, (int2)(x_out, y_out), convert_half4(out_val));
x = (g1.y + 1) * (out_width - 1) / 2;
y = (g2.y + 1) * (out_height - 1) / 2;
......@@ -55,15 +55,15 @@ __kernel void grid_sampler(__private const int out_height,
y0 = floor(y);
x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 1), convert_half4(out_val));
write_imageh(output, (int2)(x_out, y_out + 1), convert_half4(out_val));
x = (g1.z + 1) * (out_width - 1) / 2;
y = (g2.z + 1) * (out_height - 1) / 2;
......@@ -71,15 +71,15 @@ __kernel void grid_sampler(__private const int out_height,
y0 = floor(y);
x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 2), convert_half4(out_val));
write_imageh(output, (int2)(x_out, y_out + 2), convert_half4(out_val));
x = (g1.w + 1) * (out_width - 1) / 2;
y = (g2.w + 1) * (out_height - 1) / 2;
......@@ -87,13 +87,13 @@ __kernel void grid_sampler(__private const int out_height,
y0 = floor(y);
x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1));
input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 3), convert_half4(out_val));
write_imageh(output, (int2)(x_out, y_out + 3), convert_half4(out_val));
}
......@@ -24,8 +24,9 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const {
"Input(X) of BilinearInterOp should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output(Out) of BilinearInterOp should not be null.");
auto dim_x = this->param_.InputX()->dims(); // NCHW format
DLOG << "dim_x :" << dim_x;
int out_h = this->param_.OutH();
int out_w = this->param_.OutW();
PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4");
......@@ -37,8 +38,22 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const {
"OutSize's dimension size must be 1");
PADDLE_MOBILE_ENFORCE(out_size_dim[0] == 2, "OutSize's dim[0] must be 2");
}
DLOG << "this->param_.HasScale(): " << this->param_.HasScale();
if (this->param_.HasScale()) {
const float scale = this->param_.Scale();
DLOG << "scale_: " << scale;
std::vector<int64_t> dim_out({dim_x[0], dim_x[1],
static_cast<int>(dim_x[2] * scale),
static_cast<int>(dim_x[3] * scale)});
this->param_.Out()->Resize(framework::make_ddim(dim_out));
DLOG << "interp -- dim_out: " << dim_out;
} else {
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
this->param_.Out()->Resize(framework::make_ddim(dim_out));
DLOG << "interp -- dim_out: " << dim_out;
}
}
} // namespace operators
......
......@@ -3042,7 +3042,7 @@ class SplitParam : public OpParam {
int axis;
int num;
std::vector<int> sections;
// std::vector<GType> out_ts_;
// std::vector<GType> out_ts_;
#ifdef PADDLE_MOBILE_FPGA
private:
......@@ -3103,12 +3103,20 @@ class NearestInterpolationParam : public OpParam {
out_ = OutFrom<GType>(outputs, *scope);
out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs);
if (HasAttr("scale", attrs)) {
has_scale_ = true;
scale_ = GetAttr<float>("scale", attrs);
}
DLOG << "has_scale_: " << has_scale_;
DLOG << "scale_: " << scale_;
}
const GType *InputX() const { return input_x_; }
const GType *InputOutPutSize() const { return input_outsize_; }
GType *Out() const { return out_; }
int OutH() const { return out_h_; }
int OutW() const { return out_w_; }
float Scale() const { return scale_; }
bool HasScale() const { return has_scale_; }
private:
GType *input_x_;
......@@ -3116,6 +3124,8 @@ class NearestInterpolationParam : public OpParam {
GType *out_;
int out_h_;
int out_w_;
float scale_;
bool has_scale_;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册