未验证 提交 d1258513 编写于 作者: W Wilber 提交者: GitHub

[Kernel] Update gru op and batch_norm op. (#4088)

上级 0ac48fcb
@@ -55,7 +55,9 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
       output_data +
       (gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size +
       blockIdx.y * topk_size;
+  for (int i = 0; i < topk_size; ++i) {
+    fm_row_out_data[i] = 0;
+  }
   Dtype *smem_start_col = smem + idx * col_max;
   int counter = max_k;  // topk_size;
@@ -151,6 +153,9 @@ __global__ void topk_avg_pooling_kernel_for_big_data(
        blockIdx.z * actual_row_in_shared_mem + idx) *
           feat_map_num * topk_size +
       blockIdx.y * topk_size;
+  for (int i = 0; i < topk_size; ++i) {
+    fm_row_out_data[i] = 0;
+  }
   Dtype *smem_start_col = smem + idx * col_max;
@@ -239,8 +244,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
   Tensor *out_tensor = param.Out;
   const T *in_data = x_tensor->data<T>();
   T *out_data = out_tensor->mutable_data<T>(TARGET(kCUDA));
-  TargetWrapperCuda::MemsetAsync(
-      out_data, 0, sizeof(T) * param.Out->numel(), cuda_stream);
   int topk_num = param.topks.size();
   lite::DDim top_ks_shape(std::vector<int64_t>{topk_num, 1, 1, 1});
......
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "lite/operators/dropout_op.h"
 #include <string>
 #include <vector>
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
@@ -66,8 +68,10 @@ bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   param_.fix_seed = op_desc.GetAttr<bool>("fix_seed");
   param_.seed = op_desc.GetAttr<int>("seed");
-  param_.dropout_implementation =
-      op_desc.GetAttr<std::string>("dropout_implementation");
+  if (op_desc.HasAttr("dropout_implementation")) {
+    param_.dropout_implementation =
+        op_desc.GetAttr<std::string>("dropout_implementation");
+  }
   return true;
 }
......
@@ -97,7 +97,9 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   param_.gate_activation = op_desc.GetAttr<std::string>("gate_activation");
   param_.activation = op_desc.GetAttr<std::string>("activation");
   param_.is_reverse = op_desc.GetAttr<bool>("is_reverse");
-  param_.origin_mode = op_desc.GetAttr<bool>("origin_mode");
+  if (op_desc.HasAttr("origin_mode")) {
+    param_.origin_mode = op_desc.GetAttr<bool>("origin_mode");
+  }
   return true;
 }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册