提交 68183f92 编写于 作者: X xiebaiyuan 提交者: GitHub

Support scale in nearest_interp_op, test=mobile (#2633)

* Support scale in nearest_interp_op, test=mobile

* Support scale in nearest_interp_op && fix grid_sampler_kernel.cl, test=mobile

* Support scale in nearest_interp_op && fix grid_sampler_kernel.cl && fix code style, test=mobile
上级 489a18b1
...@@ -28,8 +28,8 @@ __kernel void grid_sampler(__private const int out_height, ...@@ -28,8 +28,8 @@ __kernel void grid_sampler(__private const int out_height,
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x_grid = out_h / 4 * 2; int x_grid = out_h / 4 * 2;
int y_grid = out_n * out_width + out_w; int y_grid = out_n * out_width + out_w;
float4 g1 = read_imagef(grid, sampler, int2(x_grid, y_grid)); float4 g1 = read_imagef(grid, sampler, (int2)(x_grid, y_grid));
float4 g2 = read_imagef(grid, sampler, int2(x_grid + 1, y_grid)); float4 g2 = read_imagef(grid, sampler, (int2)(x_grid + 1, y_grid));
float x = (g1.x + 1) * (out_width - 1) / 2; float x = (g1.x + 1) * (out_width - 1) / 2;
float y = (g2.x + 1) * (out_height - 1) / 2; float y = (g2.x + 1) * (out_height - 1) / 2;
...@@ -39,15 +39,15 @@ __kernel void grid_sampler(__private const int out_height, ...@@ -39,15 +39,15 @@ __kernel void grid_sampler(__private const int out_height,
int y_p = out_n * out_height + y0; int y_p = out_n * out_height + y0;
int x_out = out_c * out_width + out_w; int x_out = out_c * out_width + out_w;
int y_out = out_n * out_height + out_h; int y_out = out_n * out_height + out_h;
float4 input0 = read_imagef(input, sampler, int2(x_p, y_p)); float4 input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
float4 input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); float4 input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
float4 input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); float4 input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
float4 input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); float4 input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
float4 out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + float4 out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) + input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0); input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out), convert_half4(out_val)); write_imageh(output, (int2)(x_out, y_out), convert_half4(out_val));
x = (g1.y + 1) * (out_width - 1) / 2; x = (g1.y + 1) * (out_width - 1) / 2;
y = (g2.y + 1) * (out_height - 1) / 2; y = (g2.y + 1) * (out_height - 1) / 2;
...@@ -55,15 +55,15 @@ __kernel void grid_sampler(__private const int out_height, ...@@ -55,15 +55,15 @@ __kernel void grid_sampler(__private const int out_height,
y0 = floor(y); y0 = floor(y);
x_p = out_c * out_width + x0; x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0; y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p)); input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) + input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0); input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 1), convert_half4(out_val)); write_imageh(output, (int2)(x_out, y_out + 1), convert_half4(out_val));
x = (g1.z + 1) * (out_width - 1) / 2; x = (g1.z + 1) * (out_width - 1) / 2;
y = (g2.z + 1) * (out_height - 1) / 2; y = (g2.z + 1) * (out_height - 1) / 2;
...@@ -71,15 +71,15 @@ __kernel void grid_sampler(__private const int out_height, ...@@ -71,15 +71,15 @@ __kernel void grid_sampler(__private const int out_height,
y0 = floor(y); y0 = floor(y);
x_p = out_c * out_width + x0; x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0; y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p)); input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) + input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0); input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 2), convert_half4(out_val)); write_imageh(output, (int2)(x_out, y_out + 2), convert_half4(out_val));
x = (g1.w + 1) * (out_width - 1) / 2; x = (g1.w + 1) * (out_width - 1) / 2;
y = (g2.w + 1) * (out_height - 1) / 2; y = (g2.w + 1) * (out_height - 1) / 2;
...@@ -87,13 +87,13 @@ __kernel void grid_sampler(__private const int out_height, ...@@ -87,13 +87,13 @@ __kernel void grid_sampler(__private const int out_height,
y0 = floor(y); y0 = floor(y);
x_p = out_c * out_width + x0; x_p = out_c * out_width + x0;
y_p = out_n * out_height + y0; y_p = out_n * out_height + y0;
input0 = read_imagef(input, sampler, int2(x_p, y_p)); input0 = read_imagef(input, sampler, (int2)(x_p, y_p));
input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); input1 = read_imagef(input, sampler, (int2)(x_p + 1, y_p));
input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); input2 = read_imagef(input, sampler, (int2)(x_p, y_p + 1));
input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); input3 = read_imagef(input, sampler, (int2)(x_p + 1, y_p + 1));
out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) +
input1 * (x - x0) * (y0 + 1 - y) + input1 * (x - x0) * (y0 + 1 - y) +
input2 * (x0 + 1 - x) * (y - y0) + input2 * (x0 + 1 - x) * (y - y0) +
input3 * (x - x0) * (y - y0); input3 * (x - x0) * (y - y0);
write_imageh(output, int2(x_out, y_out + 3), convert_half4(out_val)); write_imageh(output, (int2)(x_out, y_out + 3), convert_half4(out_val));
} }
...@@ -24,8 +24,9 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const { ...@@ -24,8 +24,9 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const {
"Input(X) of BilinearInterOp should not be null."); "Input(X) of BilinearInterOp should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr, PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output(Out) of BilinearInterOp should not be null."); "Output(Out) of BilinearInterOp should not be null.");
auto dim_x = this->param_.InputX()->dims(); // NCHW format auto dim_x = this->param_.InputX()->dims(); // NCHW format
DLOG << "dim_x :" << dim_x;
int out_h = this->param_.OutH(); int out_h = this->param_.OutH();
int out_w = this->param_.OutW(); int out_w = this->param_.OutW();
PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4"); PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4");
...@@ -37,8 +38,22 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const { ...@@ -37,8 +38,22 @@ void NearestInterpolationOp<DeviceType, T>::InferShape() const {
"OutSize's dimension size must be 1"); "OutSize's dimension size must be 1");
PADDLE_MOBILE_ENFORCE(out_size_dim[0] == 2, "OutSize's dim[0] must be 2"); PADDLE_MOBILE_ENFORCE(out_size_dim[0] == 2, "OutSize's dim[0] must be 2");
} }
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
this->param_.Out()->Resize(framework::make_ddim(dim_out)); DLOG << "this->param_.HasScale(): " << this->param_.HasScale();
if (this->param_.HasScale()) {
const float scale = this->param_.Scale();
DLOG << "scale_: " << scale;
std::vector<int64_t> dim_out({dim_x[0], dim_x[1],
static_cast<int>(dim_x[2] * scale),
static_cast<int>(dim_x[3] * scale)});
this->param_.Out()->Resize(framework::make_ddim(dim_out));
DLOG << "interp -- dim_out: " << dim_out;
} else {
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
this->param_.Out()->Resize(framework::make_ddim(dim_out));
DLOG << "interp -- dim_out: " << dim_out;
}
} }
} // namespace operators } // namespace operators
......
...@@ -3042,7 +3042,7 @@ class SplitParam : public OpParam { ...@@ -3042,7 +3042,7 @@ class SplitParam : public OpParam {
int axis; int axis;
int num; int num;
std::vector<int> sections; std::vector<int> sections;
// std::vector<GType> out_ts_; // std::vector<GType> out_ts_;
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
...@@ -3103,12 +3103,20 @@ class NearestInterpolationParam : public OpParam { ...@@ -3103,12 +3103,20 @@ class NearestInterpolationParam : public OpParam {
out_ = OutFrom<GType>(outputs, *scope); out_ = OutFrom<GType>(outputs, *scope);
out_h_ = GetAttr<int>("out_h", attrs); out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs); out_w_ = GetAttr<int>("out_w", attrs);
if (HasAttr("scale", attrs)) {
has_scale_ = true;
scale_ = GetAttr<float>("scale", attrs);
}
DLOG << "has_scale_: " << has_scale_;
DLOG << "scale_: " << scale_;
} }
const GType *InputX() const { return input_x_; } const GType *InputX() const { return input_x_; }
const GType *InputOutPutSize() const { return input_outsize_; } const GType *InputOutPutSize() const { return input_outsize_; }
GType *Out() const { return out_; } GType *Out() const { return out_; }
int OutH() const { return out_h_; } int OutH() const { return out_h_; }
int OutW() const { return out_w_; } int OutW() const { return out_w_; }
float Scale() const { return scale_; }
bool HasScale() const { return has_scale_; }
private: private:
GType *input_x_; GType *input_x_;
...@@ -3116,6 +3124,8 @@ class NearestInterpolationParam : public OpParam { ...@@ -3116,6 +3124,8 @@ class NearestInterpolationParam : public OpParam {
GType *out_; GType *out_;
int out_h_; int out_h_;
int out_w_; int out_w_;
float scale_;
bool has_scale_;
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册