未验证 提交 60670acb 编写于 作者: X xiebaiyuan 提交者: GitHub

develop expend op for erciyuan ,test=mobile (#2626)

* develop expend op for erciyuan ,test=mobile

* develop expend op for erciyuan && code style ,test=mobile
上级 f3d113a1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef EXPAND_OP
#include "operators/expand_op.h"
#include <framework/ddim.h>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ExpandOp<Dtype, T>::InferShape() const {
auto x_dim = this->param_.InputX()->dims();
int expand_size = this->param_.expand_times.size();
int x_dims_size = x_dim.size();
PADDLE_MOBILE_ENFORCE(expand_size == x_dims_size,
"The number of expand_times size must be qual to the "
"rank of Input(X). The number of expand_times size "
"must be qual to the rank of Input(X).")
framework::DDim out_dims(this->param_.InputX()->dims());
for (size_t i = 0; i < this->param_.expand_times.size(); ++i) {
out_dims[i] *= this->param_.expand_times[i];
}
this->param_.Out()->Resize(out_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(expand, ops::ExpandOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef EXPAND_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/expand_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#ifdef EXPAND_OP
DECLARE_OPERATOR(Expand, ExpandParam, ExpandKernel);
#endif
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void expend_c1(
__private const int OUT_C, __private const int OUT_W,
__private const int OUT_NH,
__private const int IN_C, __private const int IN_W,
__private const int IN_NH,
__private const int input_width, /* of one block */
__private const int input_height, /* of one block */
__private const int output_width, __private const int output_height,
__read_only image2d_t input, __write_only image2d_t output,
__private const int n_times, __private const int c_times,
__private const int h_times, __private const int w_times) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= OUT_C || out_w >= OUT_W || out_nh >= OUT_NH) {
return;
}
const int out_n = out_nh / output_height;
const int out_h = out_nh % output_height;
// const real_in_c = out_c * 4 / c_times;
// const int in_c = real_in_c / 4;
const int in_c = 0;
// const int in_c = out_c / c_times;
const int in_w = out_w / w_times;
const int in_h = out_h / h_times;
const int in_n = out_n / n_times;
const int in_nh = in_n * input_height + in_h;
int2 output_pos = (int2)(out_c * OUT_W + out_w, out_nh);
int2 input_pos = (int2)(in_c * IN_W + in_w, in_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, input_pos);
in.y = in.x;
in.z = in.x;
in.w = in.x;
write_imageh(output, output_pos, in);
}
__kernel void expend_c2(
__private const int OUT_C, __private const int OUT_W,
__private const int OUT_NH,
__private const int IN_C, __private const int IN_W,
__private const int IN_NH,
__private const int input_width, /* of one block */
__private const int input_height, /* of one block */
__private const int output_width, __private const int output_height,
__read_only image2d_t input, __write_only image2d_t output,
__private const int n_times, __private const int c_times,
__private const int h_times, __private const int w_times) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= OUT_C || out_w >= OUT_W || out_nh >= OUT_NH) {
return;
}
const int out_n = out_nh / output_height;
const int out_h = out_nh % output_height;
// const real_in_c = out_c * 4 / c_times;
// const int in_c = real_in_c / 4;
const int in_c = 0;
// const int in_c = out_c / c_times;
const int in_w = out_w / w_times;
const int in_h = out_h / h_times;
const int in_n = out_n / n_times;
const int in_nh = in_n * input_height + in_h;
int2 output_pos = (int2)(out_c * OUT_W + out_w, out_nh);
int2 input_pos = (int2)(in_c * IN_W + in_w, in_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, input_pos);
in.z = in.x;
in.w = in.y;
write_imageh(output, output_pos, in);
}
__kernel void expend_c4(
__private const int OUT_C, __private const int OUT_W,
__private const int OUT_NH,
__private const int IN_C, __private const int IN_W,
__private const int IN_NH,
__private const int input_width, /* of one block */
__private const int input_height, /* of one block */
__private const int output_width, __private const int output_height,
__read_only image2d_t input, __write_only image2d_t output,
__private const int n_times, __private const int c_times,
__private const int h_times, __private const int w_times) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= OUT_C || out_w >= OUT_W || out_nh >= OUT_NH) {
return;
}
const int out_n = out_nh / output_height;
const int out_h = out_nh % output_height;
// const real_in_c = out_c * 4 / c_times;
// const int in_c = real_in_c / 4;
const int in_c = 0;
// const int in_c = out_c / c_times;
const int in_w = out_w / w_times;
const int in_h = out_h / h_times;
const int in_n = out_n / n_times;
const int in_nh = in_n * input_height + in_h;
int2 output_pos = (int2)(out_c * OUT_W + out_w, out_nh);
int2 input_pos = (int2)(in_c * IN_W + in_w, in_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, input_pos);
write_imageh(output, output_pos, in);
}
\ No newline at end of file
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef EXPAND_OP
#include "operators/kernel/expand_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ExpandKernel<GPU_CL, float>::Init(ExpandParam<GPU_CL>* param) {
const framework::DDim& input_dims = param->InputX()->dims();
PADDLE_MOBILE_ENFORCE(input_dims.size() == 4,
"expend now support 4 size dims");
if (input_dims[1] == 1) {
this->cl_helper_.AddKernel("expend_c1", "expend.cl");
} else if (input_dims[1] == 2) {
this->cl_helper_.AddKernel("expend_c2", "expend.cl");
} else if (input_dims[1] == 4) {
this->cl_helper_.AddKernel("expend_c4", "expend.cl");
} else {
PADDLE_MOBILE_ENFORCE(false, "expend did not supported this type");
}
return true;
}
template <>
void ExpandKernel<GPU_CL, float>::Compute(const ExpandParam<GPU_CL>& param) {
auto kernel = this->cl_helper_.KernelAt(0);
DLOG << "param.Out()->dims(): " << param.Out()->dims();
const framework::DDim& image_dims = param.Out()->ImageDims();
DLOG << "param.Out()->image_dims(): " << image_dims;
auto out_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
DLOG << "out_work_size: " << out_work_size;
int out_c_block = out_work_size[0];
int out_w = out_work_size[1];
int out_nh = out_work_size[2];
auto in_work_size = this->cl_helper_.DefaultWorkSize(*param.InputX());
int in_c_block = in_work_size[0];
int in_w = in_work_size[1];
int in_nh = in_work_size[2];
int input_width = param.InputX()->dims()[3];
int input_height = param.InputX()->dims()[2];
int output_width = param.Out()->dims()[3];
int output_height = param.Out()->dims()[2];
const auto* input = param.InputX();
auto* output = param.Out();
vector<int> expandTimes = {1, 1, 1, 1};
DLOG << "param.expand_times: " << param.expand_times;
for (int i = 0; i < param.expand_times.size(); ++i) {
expandTimes[i] = param.expand_times[i];
}
DLOG << "expandTimes: " << expandTimes;
auto inputImage = input->GetCLImage();
auto outputImage = output->GetCLImage();
input->dims();
int idx = 0;
cl_int status;
status = clSetKernelArg(kernel, idx++, sizeof(int), &out_c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &out_w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &out_nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &in_c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &in_w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &in_nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(cl_mem), &inputImage);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(cl_mem), &outputImage);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &expandTimes[0]);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &expandTimes[1]);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &expandTimes[2]);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, idx++, sizeof(int), &expandTimes[3]);
CL_CHECK_ERRORS(status);
status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
out_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
DLOG << *output;
}
template class ExpandKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#ifdef EXPAND_OP
DECLARE_KERNEL(Expand, ExpandParam);
#endif // EXPAND_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -3687,56 +3687,58 @@ class PixelShuffleParam : public OpParam {
};
#endif
#ifdef EXPAND_OP
#ifdef GRID_SAMPLER_OP
template <typename Dtype>
class ExpandParam : public OpParam {
class GridSamplerParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ExpandParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
GridSamplerParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
expand_times_ = GetAttr<std::vector<int>>("expand_times", attrs);
output_ = OutputFrom<GType>(outputs, *scope);
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
GType *Output() const { return output_; }
private:
GType *input_x_;
GType *out_;
std::vector<int> expand_times_;
GType *output_;
};
#endif
#ifdef GRID_SAMPLER_OP
#ifdef EXPAND_OP
template <typename Dtype>
class GridSamplerParam : public OpParam {
class ExpandParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
GridSamplerParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
Scope *scope)
ExpandParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
output_ = OutputFrom<GType>(outputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
expand_times = OpParam::GetAttr<std::vector<int>>("expand_times", attrs);
}
const GType *InputX() const { return input_x_; }
GType *Output() const { return output_; }
GType *Out() const { return out_; }
std::vector<int> expand_times;
private:
GType *input_x_;
GType *output_;
GType *out_;
};
#endif
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -237,6 +237,10 @@ if (ENABLE_ALL_TEST)
ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-expend-op operators/test_expend_op.cpp test_helper.h test_include.h executor_for_test_opencl.h )
target_link_libraries(test-expend-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mul-op operators/test_mul_op.cpp test_helper.h test_include.h)
target_link_libraries(test-mul-op paddle-mobile)
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include <memory>
#include "common/log.h"
#include "framework/executor.h"
#include "framework/op_registry.h"
......@@ -74,8 +74,11 @@ class Executor4Test : public Executor<DeviceType> {
break;
}
}
if (this->program_.combined) {
this->InitCombineMemory();
} else {
this->InitMemory();
}
for (const auto &op : this->ops_of_block0_) {
op->Init();
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_MOBILE_CL
#include <string>
#include <vector>
#include <memory>
#include "common/log.h"
#include "framework/cl/cl_helper.h"
#include "framework/cl/cl_tensor.h"
#include "framework/executor.h"
#include "framework/op_registry.h"
#include "operators/feed_op.h"
#include "operators/fetch_op.h"
#include "./test_helper.h"
using paddle_mobile::framework::BlockDesc;
using paddle_mobile::framework::DDim;
using paddle_mobile::framework::Executor;
using paddle_mobile::framework::LoDTensor;
using paddle_mobile::framework::OpDesc;
using paddle_mobile::framework::Program;
using paddle_mobile::framework::Tensor;
using paddle_mobile::framework::Variable;
using paddle_mobile::framework::OperatorBase;
using paddle_mobile::framework::AttributeMap;
using std::string;
using std::vector;
namespace paddle_mobile {
template <typename OpType>
class OpenClOpTester {
public:
OpenClOpTester() {
framework::CLEngine::Instance()->setClPath("/data/local/tmp/bin");
scope_ = std::make_shared<paddle_mobile::framework::Scope>();
feed_clhelper_ = framework::CLHelper(scope_->GetCLScpoe());
fetch_clhelper_ = framework::CLHelper(scope_->GetCLScpoe());
this->feed_clhelper_.AddKernel("feed", "feed_kernel.cl");
this->fetch_clhelper_.AddKernel("fetch", "fetch_kernel.cl");
feed_var = scope_.get()->Var("feed");
fetch_var = scope_.get()->Var("fetch");
op_in_var = scope_.get()->Var("op_in");
op_out_var = scope_.get()->Var("op_out");
}
void Predict(string op_type, DDim feed_dims, DDim fetch_dims,
VariableNameMap inputs_feed, VariableNameMap outputs_feed,
AttributeMap attrs_feed) {
framework::CLImage *const op_in_cl_image =
op_in_var->template GetMutable<framework::CLImage>();
op_in_cl_image->Resize(feed_dims);
op_in_cl_image->InitEmptyImage(feed_clhelper_.CLContext(),
feed_clhelper_.CLCommandQueue(), feed_dims);
framework::CLImage *const op_out_cl_image =
op_out_var->template GetMutable<framework::CLImage>();
op_out_cl_image->Resize(fetch_dims);
framework::CLScope *const clScpoe = scope_->GetCLScpoe();
op_out_cl_image->InitEmptyImage(clScpoe->Context(), clScpoe->CommandQueue(),
fetch_dims);
Feed(feed_dims);
auto *op = new OpType(op_type, inputs_feed, outputs_feed, attrs_feed,
scope_.get());
op->InferShape();
op->Init();
op->Run();
Fetch(fetch_dims);
}
void Feed(DDim feed_dims) {
auto *feed_var = scope_->Var("feed");
auto *_var = scope_->Var("op_in");
auto *const input = feed_var->template GetMutable<framework::LoDTensor>();
DLOG << "feed_dims: " << feed_dims;
SetupTensor<float>(input, feed_dims, -100.0, 100.0);
framework::CLImage *const op_in_cl_image =
op_in_var->template GetMutable<framework::CLImage>();
DLOG << "FeedKernel run ";
DLOG << "params.input " << *input;
DLOG << "params.op_in_cl_image " << *op_in_cl_image;
auto kernel = this->feed_clhelper_.KernelAt(0);
DLOG << "kernel get success ";
auto default_work_size =
this->feed_clhelper_.DefaultWorkSize(*(op_in_cl_image));
DLOG << "op_in_cl_image: " << *op_in_cl_image;
DLOG << "default_work_size: " << default_work_size;
cl_int status;
int numel = input->numel();
cl_mem output_image = op_in_cl_image->GetCLImage();
const int out_C = op_in_cl_image->dims()[1];
const int out_H = op_in_cl_image->dims()[2];
const int out_W = op_in_cl_image->dims()[3];
const int Stride2 = out_C * out_H * out_W;
const int Stride1 = out_H * out_W;
const int Stride0 = out_W;
framework::CLTensor input_cl_tensor(this->feed_clhelper_.CLContext(),
this->feed_clhelper_.CLCommandQueue());
input_cl_tensor.Resize(input->dims());
cl_mem inputBuffer;
inputBuffer =
input_cl_tensor.mutable_with_data<float>(input->data<float>());
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_int), &out_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_int), &out_C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_int), &Stride0);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_int), &Stride1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_int), &Stride2);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->feed_clhelper_.CLCommandQueue(), kernel, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
DLOG << "*op_in_cl_image: " << *op_in_cl_image;
}
void Fetch(DDim fetch_dims) {
DLOG << "------------------ Fetch op ---------------------";
DLOG << "------------------ Fetch op end ---------------------";
}
private:
std::shared_ptr<paddle_mobile::framework::Scope> scope_;
framework::CLHelper feed_clhelper_;
framework::CLHelper fetch_clhelper_;
Variable *feed_var;
Variable *fetch_var;
Variable *op_in_var;
Variable *op_out_var;
};
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../executor_for_test_opencl.h"
#include "operators/expand_op.h"
#include "operators/feed_op.h"
#ifdef EXPAND_OP
int main() {
const int IN_N = 1;
const int IN_C = 1;
const int IN_H = 2;
const int IN_W = 3;
const int EXPEND_N = 1;
const int EXPEND_C = 1;
const int EXPEND_H = 2;
const int EXPEND_W = 2;
const int OUT_N = IN_N * EXPEND_N;
const int OUT_C = IN_C * EXPEND_C;
const int OUT_H = IN_H * EXPEND_H;
const int OUT_W = IN_W * EXPEND_W;
framework::DDim in_dims = framework::make_ddim({IN_N, IN_C, IN_H, IN_W});
framework::DDim out_dims = framework::make_ddim({OUT_N, OUT_C, OUT_H, OUT_W});
VariableNameMap inputs;
VariableNameMap outputs;
AttributeMap attrs;
inputs["X"] = std::vector<std::string>({"op_in"});
outputs["Out"] = std::vector<std::string>({"op_out"});
std::vector<int> expand_times = {EXPEND_N, EXPEND_C, EXPEND_H, EXPEND_W};
attrs["expand_times"].Set<std::vector<int>>(expand_times);
OpenClOpTester<operators::ExpandOp<GPU_CL, float>> tester;
tester.Predict("expend", in_dims, out_dims, inputs, outputs, attrs);
}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册