未验证 提交 d1258513 编写于 作者: W Wilber 提交者: GitHub

[Kernel] Update gru op and batch_norm op. (#4088)

上级 0ac48fcb
......@@ -55,7 +55,9 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
output_data +
(gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size +
blockIdx.y * topk_size;
for (int i = 0; i < topk_size; ++i) {
fm_row_out_data[i] = 0;
}
Dtype *smem_start_col = smem + idx * col_max;
int counter = max_k; // topk_size;
......@@ -151,6 +153,9 @@ __global__ void topk_avg_pooling_kernel_for_big_data(
blockIdx.z * actual_row_in_shared_mem + idx) *
feat_map_num * topk_size +
blockIdx.y * topk_size;
for (int i = 0; i < topk_size; ++i) {
fm_row_out_data[i] = 0;
}
Dtype *smem_start_col = smem + idx * col_max;
......@@ -239,8 +244,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
Tensor *out_tensor = param.Out;
const T *in_data = x_tensor->data<T>();
T *out_data = out_tensor->mutable_data<T>(TARGET(kCUDA));
TargetWrapperCuda::MemsetAsync(
out_data, 0, sizeof(T) * param.Out->numel(), cuda_stream);
int topk_num = param.topks.size();
lite::DDim top_ks_shape(std::vector<int64_t>{topk_num, 1, 1, 1});
......
......@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/dropout_op.h"
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
......@@ -66,8 +68,10 @@ bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.fix_seed = op_desc.GetAttr<bool>("fix_seed");
param_.seed = op_desc.GetAttr<int>("seed");
param_.dropout_implementation =
op_desc.GetAttr<std::string>("dropout_implementation");
if (op_desc.HasAttr("dropout_implementation")) {
param_.dropout_implementation =
op_desc.GetAttr<std::string>("dropout_implementation");
}
return true;
}
......
......@@ -97,7 +97,9 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.gate_activation = op_desc.GetAttr<std::string>("gate_activation");
param_.activation = op_desc.GetAttr<std::string>("activation");
param_.is_reverse = op_desc.GetAttr<bool>("is_reverse");
param_.origin_mode = op_desc.GetAttr<bool>("origin_mode");
if (op_desc.HasAttr("origin_mode")) {
param_.origin_mode = op_desc.GetAttr<bool>("origin_mode");
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册