Commit d3dcbd37 authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_sgd_to_phi

@@ -52,8 +52,12 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
   }
 }
-static void RetainGradForRegularNode(
-    const paddle::experimental::Tensor& tensor) {
+void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
+  if (IsLeafTensor(tensor)) {
+    // Leaf tensor's grad will always be retained
+    // Refer to implementation of AccumulationNode for more details
+    return;
+  } else {
   AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
   if (meta->RetainGrads()) {
     return;
@@ -86,15 +90,6 @@ static void RetainGradForRegularNode(
   // Append to GradientHooks
   RegisterGradientHookForTensor(tensor,
                                 std::make_shared<egr::CppTensorHook>(hook));
-}
-
-void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
-  if (IsLeafTensor(tensor)) {
-    // Leaf tensor's grad will always be retained
-    // Refer to implementation of AccumulationNode for more details
-    return;
-  } else {
-    RetainGradForRegularNode(tensor);
   }
 }
......
@@ -1156,11 +1156,13 @@ static std::string GenerateGradNodeCreationContent(
      grad_node_creation_str += paddle::string::Sprintf(
          SET_OUT_RANK_TEMPLATE, output_autograd_name, output_position);
+      // Intermediate Tensor does not require SetHistory
+      if (!output.intermediate()) {
      const char* SET_HISTORY_TEMPLATE =
          " egr::EagerUtils::SetHistory(&%s, grad_node);\n";
      grad_node_creation_str +=
          paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      }
      const char* SET_GRAD_IN_META_TEMPLATE =
          " grad_node->SetGradInMeta(&%s, %d);\n";
      grad_node_creation_str += paddle::string::Sprintf(
@@ -1173,17 +1175,20 @@ static std::string GenerateGradNodeCreationContent(
      grad_node_creation_str += paddle::string::Sprintf(
          SET_OUT_RANK_TEMPLATE, output_autograd_name, output_position);
+      // Intermediate Tensor does not require SetHistory
+      if (!output.intermediate()) {
      const char* SET_HISTORY_TEMPLATE =
          " egr::EagerUtils::SetHistory(%s, grad_node);\n";
      grad_node_creation_str +=
          paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      }
      const char* SET_GRAD_IN_META_TEMPLATE =
          " grad_node->SetGradInMeta(%s, %d);\n";
      grad_node_creation_str += paddle::string::Sprintf(
          SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
    }
+    // Intermediate Tensor does not require CheckAndRetainGrad
    if (!output.intermediate()) {
      VLOG(6) << "Generated Call RetainGradForTensor";
      const char* RETAIN_GRAD_TEMPLATE =
......
@@ -24,6 +24,17 @@ core_ops_args_info = {}
core_ops_args_type_info = {}

+yaml_types_mapping = {
+    'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
+    'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
+    'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \
+    'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
+    'Tensor' : 'Tensor',
+    'Tensor[]' : 'std::vector<Tensor>',
+    'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>'
+}
+
def ParseArguments():
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser')
@@ -59,7 +70,9 @@ def IsPlainTensorType(string):
def IsVectorTensorType(string):
-    vector_tensor_types = ['list(Tensor)']
+    vector_tensor_types = [
+        'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
+    ]
    if string in vector_tensor_types:
        return True
    return False
@@ -180,6 +193,9 @@ def ParseYamlArgs(string):
        arg_name = m.group(3).split("=")[0].strip()
        default_value = m.group(3).split("=")[1].strip() if len(
            m.group(3).split("=")) > 1 else None
+
+        assert arg_type in yaml_types_mapping.keys()
+        arg_type = yaml_types_mapping[arg_type]
        if "Tensor" in arg_type:
            assert default_value is None
            inputs_list.append([arg_name, arg_type, i])
@@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string):
        m = re.search(pattern, ret)
        ret_type = m.group(1)
        ret_name = m.group(2)
+
+        assert ret_type in yaml_types_mapping.keys()
+        ret_type = yaml_types_mapping[ret_type]
+
        assert "Tensor" in ret_type
        returns_list.append([ret_name, ret_type, i])
......
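The yaml_types_mapping added above is a plain string-to-string table: the generator looks up each yaml type name and, if it is known, rewrites it to the corresponding C++ type before emitting code. Below is a minimal standalone sketch of that lookup, written in C++ with a hypothetical MapYamlType helper (not the generator's real Python code) just to show the table-plus-assert pattern.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical mirror of yaml_types_mapping: yaml type name -> C++ type name.
static const std::map<std::string, std::string> kYamlTypesMapping = {
    {"int", "int"},
    {"int64_t[]", "std::vector<int64_t>"},
    {"Tensor", "Tensor"},
    {"Tensor[]", "std::vector<Tensor>"},
    {"Tensor[Tensor[]]", "std::vector<std::vector<Tensor>>"}};

// Translate one yaml type, failing loudly like the generator's assert does.
std::string MapYamlType(const std::string& yaml_type) {
  auto it = kYamlTypesMapping.find(yaml_type);
  if (it == kYamlTypesMapping.end()) {
    throw std::invalid_argument("unknown yaml type: " + yaml_type);
  }
  return it->second;
}

int main() {
  std::cout << MapYamlType("Tensor[]") << "\n";   // std::vector<Tensor>
  std::cout << MapYamlType("int64_t[]") << "\n";  // std::vector<int64_t>
  return 0;
}

This is also why IsVectorTensorType now matches the mapped names ('std::vector<Tensor>' and the nested variant) instead of the raw yaml spelling.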
@@ -221,10 +221,11 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
            << " 's name is: " << grad_output_tensor.name();
    auto* next_node = next_node_shared.get();
    if (!node_input_buffers_dict.count(next_node)) {
-      node_input_buffers_dict[next_node] =
-          std::make_unique<GradTensorHolder>(next_node->InputMeta());
+      const auto& input_meta = next_node->InputMeta();
+      auto grad_tensor_holder =
+          std::make_unique<GradTensorHolder>(input_meta);
+      node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
    }
    VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
            << ", rank: " << edge_rank.second;
......
@@ -244,7 +244,7 @@ GradNodeBase::ApplyGradientHooks(
      if (!out.defined() || !out.initialized()) {
        out = (*hook)(tensors[slot_id][rank]);
      } else {
        // If more than one hook is registered, the input to the next hook func
        // should be the output of the previous hook
        out = (*hook)(out);
      }
......
@@ -122,12 +122,21 @@ paddle::experimental::Tensor* EagerUtils::mutable_grad(
void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
  for (const auto& autograd_meta : *autograd_metas) {
+    if (dynamic_cast<GradNodeAccumulation*>(autograd_meta->GradNode())) {
+      VLOG(6) << "Warning: Reseting GradNodeAccumulation for leaf tensor is "
+                 "detected";
+    }
    autograd_meta->SetGradNode(grad_node);
  }
}

void EagerUtils::SetHistory(AutogradMeta* autograd_meta,
                            const std::shared_ptr<GradNodeBase>& grad_node) {
+  if (dynamic_cast<GradNodeAccumulation*>(autograd_meta->GradNode())) {
+    VLOG(6)
+        << "Warning: Reseting GradNodeAccumulation for leaf tensor is detected";
+  }
  autograd_meta->SetGradNode(grad_node);
}
......
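The new checks above only log when SetHistory is about to overwrite a leaf tensor's GradNodeAccumulation; dynamic_cast returns nullptr for every other node type, so the warning fires only for accumulation nodes. A small standalone sketch of that detection pattern, with simplified stand-in types rather than the eager API:

#include <iostream>

// Simplified stand-ins for GradNodeBase / GradNodeAccumulation.
struct GradNodeBase { virtual ~GradNodeBase() = default; };
struct GradNodeAccumulation : GradNodeBase {};
struct GradNodeScale : GradNodeBase {};

void SetHistorySketch(GradNodeBase* current, const char* tensor_name) {
  // dynamic_cast yields nullptr unless the node really is an accumulation node.
  if (dynamic_cast<GradNodeAccumulation*>(current)) {
    std::cout << "Warning: resetting GradNodeAccumulation for leaf tensor "
              << tensor_name << "\n";
  }
  // ...the real SetHistory then overwrites the node via SetGradNode.
}

int main() {
  GradNodeAccumulation leaf_node;
  GradNodeScale regular_node;
  SetHistorySketch(&leaf_node, "x");     // triggers the warning
  SetHistorySketch(&regular_node, "y");  // silent
  return 0;
}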
@@ -88,6 +88,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
    return var_types[0] == proto::VarType::SELECTED_ROWS;
  }
+  bool IsForInferShape() const override { return true; }
+
 private:
  const InferShapeContext& ctx_;
};
...@@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor {
} }
} else { } else {
auto* var = BOOST_GET_CONST(VarDesc*, var_); auto* var = BOOST_GET_CONST(VarDesc*, var_);
return phi::make_ddim(var->GetShape());
return var->GetShape().empty() ? phi::make_ddim({0UL})
: phi::make_ddim(var->GetShape());
} }
} }
......
@@ -489,6 +489,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
    return ctx_.OutputVar(name)->IsType<phi::SelectedRows>();
  }
+  bool IsForInferShape() const override { return false; }
+
 private:
  const ExecutionContext& ctx_;
};
......
@@ -125,6 +125,15 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                          kernel_key.dtype());
  }
+#endif
+#ifdef PADDLE_WITH_IPU
+  if (platform::is_ipu_place(expected_kernel_key.place_)) {
+    VLOG(3) << "pten missing IPU kernel: " << op.Type()
+            << ", expected_kernel_key:" << expected_kernel_key
+            << ", fallbacking to CPU one!";
+    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
+                          kernel_key.dtype());
+  }
#endif
  return phi::KernelKey();
}
......
@@ -453,6 +453,23 @@ if(WITH_MKLDNN)
      download_int8_data_without_verify(${INT8_GOOGLENET_MODEL_DIR} "GoogleNet_int8_model.tar.gz" )
      inference_analysis_api_int8_test_run_custom_warmup_batch_size(test_analyzer_int8_googlenet ${INT8_IMG_CLASS_TEST_APP} ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} 10)
+     # mobilenetv3_large_x1_0 int8
+     set(INT8_MOBILENETV3_LARGE_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv3_large")
+     set(INT8_MOBILENETV3_FILE_NAME "MobileNetV3_large_x1_0_infer.tar")
+     if (NOT EXISTS ${INT8_MOBILENETV3_LARGE_MODEL_DIR}/${INT8_MOBILENETV3_FILE_NAME})
+         inference_download_and_uncompress_without_verify(${INT8_MOBILENETV3_LARGE_MODEL_DIR} "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/" ${INT8_MOBILENETV3_FILE_NAME})
+     endif()
+     inference_analysis_test_run(test_analyzer_int8_mobilenetv3_large
+           COMMAND ${INT8_IMG_CLASS_TEST_APP}
+           ARGS --infer_model=${INT8_MOBILENETV3_LARGE_MODEL_DIR}/MobileNetV3_large_x1_0_infer
+                --infer_data=${IMAGENET_DATA_PATH}
+                --warmup_batch_size=50
+                --batch_size=1
+                --enable_int8=true
+                --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
+                --iterations=100
+                --with_accuracy_layer=false)
      ### BFLOAT16 tests
      # build test binary to be used in subsequent tests
@@ -472,6 +489,17 @@ if(WITH_MKLDNN)
      # mobilenetv2 bfloat16
      inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_mobilenetv2 ${BF16_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH})
+     # mobilenetv3_large
+     inference_analysis_test_run(test_analyzer_bfloat16_mobilenetv3_large
+           COMMAND ${BF16_IMG_CLASS_TEST_APP}
+           ARGS --infer_model=${INT8_MOBILENETV3_LARGE_MODEL_DIR}/MobileNetV3_large_x1_0_infer
+                --infer_data=${IMAGENET_DATA_PATH}
+                --batch_size=1
+                --enable_bf16=true
+                --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+                --iterations=100
+                --with_accuracy_layer=false)
      ### Object detection models
      set(PASCALVOC_DATA_PATH "${INT8_DATA_DIR}/pascalvoc_val_head_300.bin")
      set(INT8_OBJ_DETECT_TEST_APP "test_analyzer_int8_object_detection")
@@ -739,6 +767,7 @@ if(WITH_MKLDNN)
    set_tests_properties(test_analyzer_quant_performance_benchmark PROPERTIES TIMEOUT 120)
    set_tests_properties(test_analyzer_int8_mobilenetv2 PROPERTIES TIMEOUT 120)
    set_tests_properties(test_analyzer_int8_mobilenetv1 PROPERTIES TIMEOUT 120)
+   set_tests_properties(test_analyzer_int8_mobilenetv3_large PROPERTIES TIMEOUT 120)
  endif()
  set_tests_properties(lite_resnet50_test PROPERTIES TIMEOUT 120)
......
@@ -14,13 +14,19 @@ limitations under the License. */
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
+#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace inference {
namespace analysis {

void SetConfig(AnalysisConfig *cfg) {
-  cfg->SetModel(FLAGS_infer_model);
+  std::ifstream model_file(FLAGS_infer_model + "/__model__");
+  if (model_file.good())
+    cfg->SetModel(FLAGS_infer_model);
+  else
+    cfg->SetModel(FLAGS_infer_model + "/inference.pdmodel",
+                  FLAGS_infer_model + "/inference.pdiparams");
  cfg->DisableGpu();
  cfg->SwitchIrOptim();
  cfg->SwitchSpecifyInputNames();
@@ -38,7 +44,12 @@ TEST(Analyzer_bfloat16_image_classification, bfloat16) {
  // read data from file and prepare batches with test data
  std::vector<std::vector<PaddleTensor>> input_slots_all;
  SetInputs(&input_slots_all);
-  b_cfg.EnableMkldnnBfloat16();
+  if (FLAGS_enable_bf16 &&
+      platform::MayIUse(platform::cpu_isa_t::avx512_bf16)) {
+    b_cfg.EnableMkldnnBfloat16();
+  } else {
+    FLAGS_enable_bf16 = false;
+  }
  CompareBFloat16AndAnalysis(&cfg, &b_cfg, input_slots_all);
}
......
@@ -22,7 +22,12 @@ namespace inference {
namespace analysis {

void SetConfig(AnalysisConfig *cfg) {
-  cfg->SetModel(FLAGS_infer_model);
+  std::ifstream model_file(FLAGS_infer_model + "/__model__");
+  if (model_file.good())
+    cfg->SetModel(FLAGS_infer_model);
+  else
+    cfg->SetModel(FLAGS_infer_model + "/inference.pdmodel",
+                  FLAGS_infer_model + "/inference.pdiparams");
  cfg->DisableGpu();
  cfg->SwitchIrOptim();
  cfg->SwitchSpecifyInputNames();
......
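Both SetConfig changes above pick the model format by probing for a __model__ file: if it exists the directory is the old uncombined layout, otherwise the combined inference.pdmodel/inference.pdiparams pair is used. A standalone sketch of that probe, assuming the same two layouts and a hypothetical ResolveModel helper:

#include <fstream>
#include <iostream>
#include <string>
#include <utility>

// Return program/params paths; an empty second element means the uncombined
// (__model__) layout, mirroring the two SetModel overloads used above.
std::pair<std::string, std::string> ResolveModel(const std::string& model_dir) {
  std::ifstream model_file(model_dir + "/__model__");
  if (model_file.good()) {
    return {model_dir, ""};  // old uncombined format
  }
  return {model_dir + "/inference.pdmodel",
          model_dir + "/inference.pdiparams"};  // combined format
}

int main() {
  auto paths = ResolveModel("./MobileNetV3_large_x1_0_infer");
  std::cout << "program: " << paths.first << "\n"
            << "params: " << (paths.second.empty() ? "<dir>" : paths.second)
            << "\n";
  return 0;
}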
@@ -213,15 +213,15 @@ std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
                    element_in_batch * 3 * 224 * 224,
                3 * 224 * 224,
                static_cast<float *>(images.data.data()) + i * 3 * 224 * 224);
+    if (FLAGS_with_accuracy_layer)
      std::copy_n(static_cast<int64_t *>(test_data[batch][1].data.data()) +
                      element_in_batch,
                  1, static_cast<int64_t *>(labels.data.data()) + i);
  }
-  auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(2);
+  auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(
+      FLAGS_with_accuracy_layer ? 2 : 1);
  (*warmup_data)[0] = std::move(images);
-  (*warmup_data)[1] = std::move(labels);
+  if (FLAGS_with_accuracy_layer) (*warmup_data)[1] = std::move(labels);
  return warmup_data;
}
@@ -254,9 +254,13 @@ void SetInputs(std::vector<std::vector<PaddleTensor>> *inputs,
  }
  for (auto i = 0; i < iterations; i++) {
    auto images = image_reader.NextBatch();
-    auto labels = label_reader.NextBatch();
-    inputs->emplace_back(
-        std::vector<PaddleTensor>{std::move(images), std::move(labels)});
+    std::vector<PaddleTensor> tmp_vec;
+    tmp_vec.push_back(std::move(images));
+    if (FLAGS_with_accuracy_layer) {
+      auto labels = label_reader.NextBatch();
+      tmp_vec.push_back(std::move(labels));
+    }
+    inputs->push_back(std::move(tmp_vec));
  }
}
@@ -825,6 +829,7 @@ void CompareQuantizedAndAnalysis(
  SummarizePerformance("FP32", sample_latency_fp32, "INT8",
                       sample_latency_int8);
+  if (FLAGS_with_accuracy_layer)
    CompareAccuracy(quantized_outputs, analysis_outputs, compared_idx);
}
@@ -864,6 +869,7 @@ void CompareBFloat16AndAnalysis(
  SummarizePerformance("FP32", sample_latency_fp32, "BF16",
                       sample_latency_bf16);
+  if (FLAGS_with_accuracy_layer)
    CompareAccuracy(bf16_outputs, analysis_outputs, compared_idx);
}
......
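The tester_helper changes make the label tensor optional: warmup data and per-iteration inputs are sized by whether an accuracy layer is present. A reduced sketch of that shape, with plain ints standing in for PaddleTensor:

#include <iostream>
#include <memory>
#include <vector>

// Stand-in for PaddleTensor; only the element count matters in this sketch.
using FakeTensor = int;

std::shared_ptr<std::vector<FakeTensor>> MakeWarmupData(bool with_accuracy_layer) {
  // One slot for images, plus a second slot for labels only when needed.
  auto warmup_data =
      std::make_shared<std::vector<FakeTensor>>(with_accuracy_layer ? 2 : 1);
  (*warmup_data)[0] = 100;                        // images
  if (with_accuracy_layer) (*warmup_data)[1] = 7; // labels
  return warmup_data;
}

int main() {
  std::cout << MakeWarmupData(true)->size() << "\n";   // 2: images + labels
  std::cout << MakeWarmupData(false)->size() << "\n";  // 1: images only
  return 0;
}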
@@ -18,6 +18,10 @@
#include <utility>
#include <vector>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
namespace paddle {
namespace operators {
@@ -92,9 +96,13 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
  virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};
+DELCARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor,
+                            PT_INFER_META(phi::MeanRawInferMeta));
+
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
                  ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>,
-                  ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>);
+                  ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>,
+                  ReduceMeanInferShapeFunctor);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
                  ops::ReduceMeanDoubleGradDescMaker,
                  ops::ReduceMeanDoubleGradOpBaseMaker,
......
@@ -16,6 +16,10 @@
#include <string>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
namespace paddle {
namespace framework {
class OpDesc;
@@ -98,10 +102,14 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker {
  virtual std::string GetOpType() const { return "Reduce reduce_sum"; }
};
+DELCARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor,
+                            PT_INFER_META(phi::ReduceInferMetaBase));
+
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
                  ops::ReduceSumVarTypeInference,
                  ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
-                  ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
+                  ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>,
+                  ReduceSumInferShapeFunctor);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
                  ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
                  ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
......
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#include "paddle/fluid/operators/selu_op.h"
-
#include <memory>
#include <string>
#include <unordered_map>
@@ -127,9 +125,3 @@ REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType,
                  ops::SeluGradMaker<paddle::framework::OpDesc>,
                  ops::SeluGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);
-REGISTER_OP_CPU_KERNEL(
-    selu, ops::SeluKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SeluKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    selu_grad, ops::SeluGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SeluGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/selu_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
selu, ops::SeluKernel<paddle::platform::CUDADeviceContext, float>,
ops::SeluKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
selu_grad, ops::SeluGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SeluGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -24,7 +24,8 @@ namespace paddle {
namespace platform {
namespace ipu {
-struct IpuStrategy {
+class IpuStrategy {
+ public:
  IpuStrategy();
  // TODO(alleng) create PaddleOptions
@@ -75,22 +76,30 @@ struct IpuStrategy {
  // custom ops
  std::vector<IpuCustomOpIdentifier> custom_ops;
- private:
-  std::map<std::string, std::function<void(bool)>> bool_options;
-  std::map<std::string, std::function<void(std::uint64_t)>> uint64_options;
-  std::map<std::string, std::function<void(double)>> double_options;
-  std::map<std::string, std::function<void(std::string)>> string_options;
-  std::map<std::string,
-           std::function<void(std::pair<std::string, std::string>)>>
-      container_options;
-
-  std::map<std::string, std::function<std::string()>> options_getter;
-  std::map<std::string, std::function<std::vector<std::string>()>>
-      vector_options_getter;
-  std::map<std::string, std::function<std::map<std::string, std::string>()>>
-      map_options_getter;
-  std::map<std::string, std::string> options_type;
+ public:
+  void AddBoolOption(const std::string &option, bool value);
+  void AddUint64Option(const std::string &option, std::uint64_t value);
+  void AddDoubleOption(const std::string &option, double value);
+  void AddStringOption(const std::string &option, const std::string &value);
+  void InsertStringOption(const std::string &option, const std::string &value);
+  void InsertStringPairOption(const std::string &option, const std::string &key,
+                              const std::string &value);
+  void SetTensorLocation(const std::string &tensor, const std::string &option,
+                         std::uint64_t value);
+  void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
+                   const std::string &domain, int version);
+
+  std::string GetOption(const std::string &);
+  std::vector<std::string> GetVectorOption(const std::string &);
+  std::map<std::string, std::string> GetMapOption(const std::string &);
+  std::string GetOptionType(const std::string &);
+  std::vector<std::string> GetAllOptionNames();
+
+  void EnablePattern(const std::string &t);
+  void DisablePattern(const std::string &t);
+  const bool IsPatternEnabled(const std::string &t);
+
+ private:
  template <typename ValueType>
  void set(
      const std::string &key, ValueType value,
@@ -117,27 +126,20 @@ struct IpuStrategy {
    return it->second();
  }
- public:
-  void AddBoolOption(const std::string &option, bool value);
-  void AddUint64Option(const std::string &option, std::uint64_t value);
-  void AddDoubleOption(const std::string &option, double value);
-  void AddStringOption(const std::string &option, const std::string &value);
-  void InsertStringOption(const std::string &option, const std::string &value);
-  void InsertStringPairOption(const std::string &option, const std::string &key,
-                              const std::string &value);
-  void SetTensorLocation(const std::string &tensor, const std::string &option,
-                         std::uint64_t value);
-  void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
-                   const std::string &domain, int version);
-
-  std::string GetOption(const std::string &);
-  std::vector<std::string> GetVectorOption(const std::string &);
-  std::map<std::string, std::string> GetMapOption(const std::string &);
-  std::string GetOptionType(const std::string &);
-
-  void EnablePattern(const std::string &t);
-  void DisablePattern(const std::string &t);
-  const bool IsPatternEnabled(const std::string &t);
+  std::map<std::string, std::function<void(bool)>> bool_options;
+  std::map<std::string, std::function<void(std::uint64_t)>> uint64_options;
+  std::map<std::string, std::function<void(double)>> double_options;
+  std::map<std::string, std::function<void(std::string)>> string_options;
+  std::map<std::string,
+           std::function<void(std::pair<std::string, std::string>)>>
+      container_options;
+
+  std::map<std::string, std::function<std::string()>> options_getter;
+  std::map<std::string, std::function<std::vector<std::string>()>>
+      vector_options_getter;
+  std::map<std::string, std::function<std::map<std::string, std::string>()>>
+      map_options_getter;
+  std::map<std::string, std::string> options_type;
};

}  // namespace ipu
......
@@ -3919,6 +3919,8 @@ All parameter, weight, gradient are variables in Paddle.
            }
            return res;
          })
+      .def("get_all_option_names",
+           &platform::ipu::IpuStrategy::GetAllOptionNames)
      .def("enable_pattern", &platform::ipu::IpuStrategy::EnablePattern)
      .def("disable_pattern", &platform::ipu::IpuStrategy::DisablePattern)
      .def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled);
......
@@ -46,6 +46,8 @@ class ProtoArgumentMappingContext : public phi::ArgumentMappingContext {
  bool IsDenseTensorOutput(const std::string& name) const override;
  bool IsSelectedRowsOutput(const std::string& name) const override;
+  bool IsForInferShape() const override { return false; }
+
 private:
  mlir::Operation* op_;
  const std::unordered_map<std::string, uint8_t>& input_map_;
......
@@ -91,6 +91,10 @@ class ArgumentMappingContext {
  virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
  virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
+  // use this function to mark it comes from InferShapeArgumentMappingContext
+  // and will be used in infershape
+  virtual bool IsForInferShape() const = 0;
};

}  // namespace phi
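The new pure virtual IsForInferShape() lets argument-mapping code tell a compile-time InferShape pass (returns true) apart from the runtime execution contexts added above (return false). A minimal sketch of that split with a simplified interface, not phi's real one; PickSignature is a hypothetical caller:

#include <iostream>
#include <string>

// Reduced ArgumentMappingContext: only the new query is modeled here.
class ArgumentMappingContextSketch {
 public:
  virtual ~ArgumentMappingContextSketch() = default;
  virtual bool IsForInferShape() const = 0;
};

class InferShapeContextSketch : public ArgumentMappingContextSketch {
 public:
  bool IsForInferShape() const override { return true; }   // compile-time shape pass
};

class ExecutionContextSketch : public ArgumentMappingContextSketch {
 public:
  bool IsForInferShape() const override { return false; }  // runtime kernel dispatch
};

std::string PickSignature(const ArgumentMappingContextSketch& ctx) {
  // An argument-mapping function could branch on the flag like this.
  return ctx.IsForInferShape() ? "infershape path" : "execution path";
}

int main() {
  std::cout << PickSignature(InferShapeContextSketch{}) << "\n";
  std::cout << PickSignature(ExecutionContextSketch{}) << "\n";
  return 0;
}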
@@ -375,7 +375,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
  ReshapeInferMeta(x, shape, out, config);
}
-/* Why not use ReduceInferMeta directly?
+/* Why not use ReduceInferMetaBase directly?
   Because we need make InferMetaFunction's args follow the design of api.yaml
 */
void SumInferMeta(const MetaTensor& x,
@@ -383,22 +383,53 @@ void SumInferMeta(const MetaTensor& x,
                  DataType dtype,
                  bool keep_dim,
                  MetaTensor* out) {
-  ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
+  bool reduce_all = false;
+  ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out);
}

void ReduceInferMetaBase(const MetaTensor& x,
                         const std::vector<int64_t>& axis,
                         bool keep_dim,
+                         bool reduce_all,
                         DataType dtype,
                         MetaTensor* out) {
-  bool reduce_all = true;
-  std::set<int64_t> dims_set(axis.begin(), axis.end());
+  auto x_rank = x.dims().size();
+  std::vector<int64_t> formated_axis = axis;
+  for (size_t i = 0; i < axis.size(); ++i) {
+    PADDLE_ENFORCE_LT(axis[i],
+                      x_rank,
+                      errors::InvalidArgument(
+                          "The reduce dim index %d should be in the "
+                          "range [-dimension(X), dimension(X)] "
+                          "which dimesion = %d. But received dim index = %d.",
+                          i,
+                          x_rank,
+                          axis[i]));
+    PADDLE_ENFORCE_GE(axis[i],
+                      -x_rank,
+                      errors::InvalidArgument(
+                          "The reduce dim index %d should be in the "
+                          "range [-dimension(X), dimension(X)] "
+                          "which dimesion = %d. But received dim index = %d.",
+                          i,
+                          x_rank,
+                          axis[i]));
+    if (axis[i] < 0) {
+      formated_axis[i] = axis[i] + x_rank;
+    }
+  }
+
+  bool full_dim = true;
+  std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
  for (int64_t i = 0; i < x.dims().size(); ++i) {
    if (dims_set.find(i) == dims_set.end()) {
-      reduce_all = false;
+      full_dim = false;
      break;
    }
  }
+  reduce_all = reduce_all || full_dim;

  std::vector<int64_t> out_dim_vector;
  if (keep_dim) {
@@ -441,11 +472,20 @@ void ReduceInferMetaBase(const MetaTensor& x,
  out->set_layout(x.layout());
}
-void ReduceInferMeta(const MetaTensor& x,
+void MeanRawInferMeta(const MetaTensor& x,
+                      const std::vector<int64_t>& axis,
+                      bool keep_dim,
+                      bool reduce_all,
+                      MetaTensor* out) {
+  ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out);
+}
+
+void MeanInferMeta(const MetaTensor& x,
                    const std::vector<int64_t>& axis,
                    bool keep_dim,
                    MetaTensor* out) {
-  ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
+  bool reduce_all = false;
+  ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out);
}

void TransferLayoutInferMeta(const MetaTensor& x,
......
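The new ReduceInferMetaBase body does two things before computing output dims: it validates and normalizes negative axes into [0, rank), and it ORs the caller-supplied reduce_all with "every dimension is listed". A host-side sketch of just that logic, using a hypothetical free function with no MetaTensor involved:

#include <iostream>
#include <set>
#include <stdexcept>
#include <vector>

// Normalize axes the way ReduceInferMetaBase does and decide whether the
// reduction effectively covers all dimensions of a rank-x_rank tensor.
bool NormalizeAndCheckFullReduce(std::vector<int64_t>* axis, int64_t x_rank,
                                 bool reduce_all) {
  for (auto& a : *axis) {
    if (a >= x_rank || a < -x_rank) {
      throw std::invalid_argument("reduce dim out of range");
    }
    if (a < 0) a += x_rank;  // e.g. axis -1 on a rank-3 tensor becomes 2
  }
  std::set<int64_t> dims_set(axis->begin(), axis->end());
  bool full_dim = true;
  for (int64_t i = 0; i < x_rank; ++i) {
    if (dims_set.count(i) == 0) {
      full_dim = false;
      break;
    }
  }
  return reduce_all || full_dim;
}

int main() {
  std::vector<int64_t> all_axes = {-1, 0, 1};
  std::cout << NormalizeAndCheckFullReduce(&all_axes, 3, false) << "\n";  // 1
  std::vector<int64_t> partial = {-1};
  std::cout << NormalizeAndCheckFullReduce(&partial, 3, false) << "\n";   // 0
  return 0;
}

This also explains the MeanRawInferMeta/MeanInferMeta split: the "raw" variant exposes reduce_all, while MeanInferMeta and SumInferMeta pass reduce_all = false and rely on the full-dim check.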
@@ -86,10 +86,17 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
void ReduceInferMetaBase(const MetaTensor& x,
                         const std::vector<int64_t>& axis,
                         bool keep_dim,
+                         bool reduce_all,
                         DataType dtype,
                         MetaTensor* out);

-void ReduceInferMeta(const MetaTensor& x,
+void MeanRawInferMeta(const MetaTensor& x,
+                      const std::vector<int64_t>& axis,
+                      bool keep_dim,
+                      bool reduce_all,
+                      MetaTensor* out);
+
+void MeanInferMeta(const MetaTensor& x,
                    const std::vector<int64_t>& axis,
                    bool keep_dim,
                    MetaTensor* out);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selu_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
selu_grad, CPU, ALL_LAYOUT, phi::SeluGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selu_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/selu_kernel_impl.h"
PD_REGISTER_KERNEL(selu, CPU, ALL_LAYOUT, phi::SeluKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selu_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
selu_grad, GPU, ALL_LAYOUT, phi::SeluGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selu_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/selu_kernel_impl.h"
PD_REGISTER_KERNEL(selu, GPU, ALL_LAYOUT, phi::SeluKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/impl/selu_kernel_impl.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SeluGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
float scale,
float alpha,
DenseTensor* dx) {
auto dx_ptr = dev_ctx.template Alloc<T>(dx);
SeluGradFunctor<T> functor(
out.data<T>(), dout.data<T>(), alpha, scale, dx_ptr);
size_t limit = static_cast<size_t>(out.numel());
paddle::platform::ForRange<Context> for_range(dev_ctx, limit);
for_range(functor);
}
} // namespace phi
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

#pragma once
#include <string>
-#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/core/dense_tensor.h"

-namespace paddle {
-namespace operators {
+namespace phi {

template <typename T>
struct SeluFunctor {
@@ -32,7 +31,7 @@ struct SeluFunctor {
  HOSTDEVICE void operator()(size_t idx) const {
    T x_ele = x_data_ptr_[idx];
    if (x_ele <= 0) {
-      x_ele = alpha_ * real_exp(x_ele) - alpha_;
+      x_ele = alpha_ * paddle::operators::real_exp(x_ele) - alpha_;
    }
    y_data_ptr_[idx] = scale_ * x_ele;
  }
@@ -44,8 +43,11 @@ struct SeluFunctor {
template <typename T>
struct SeluGradFunctor {
-  SeluGradFunctor(const T* y_data_ptr, const T* dy_data_ptr, float alpha,
-                  float scale, T* dx_data_ptr)
+  SeluGradFunctor(const T* y_data_ptr,
+                  const T* dy_data_ptr,
+                  float alpha,
+                  float scale,
+                  T* dx_data_ptr)
      : y_data_ptr_(y_data_ptr),
        dy_data_ptr_(dy_data_ptr),
        alpha_(alpha),
@@ -71,53 +73,16 @@ struct SeluGradFunctor {
  T* dx_data_ptr_;
};

-template <typename DeviceContext, typename T>
-class SeluKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    using Tensor = framework::Tensor;
-
-    auto* x = context.Input<Tensor>("X");
-    auto* out = context.Output<Tensor>("Out");
-
-    float alpha = context.Attr<float>("alpha");
-    float scale = context.Attr<float>("scale");
-
-    auto out_ptr = out->mutable_data<T>(context.GetPlace());
-
-    SeluFunctor<T> functor(x->data<T>(), alpha, scale, out_ptr);
-
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    size_t limit = static_cast<size_t>(x->numel());
-    platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
-    for_range(functor);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class SeluGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    using Tensor = framework::Tensor;
-
-    auto* out = context.Input<Tensor>("Out");
-    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
-
-    float alpha = context.Attr<float>("alpha");
-    float scale = context.Attr<float>("scale");
-
-    auto dx_ptr = dx->mutable_data<T>(context.GetPlace());
-
-    SeluGradFunctor<T> functor(out->data<T>(), dout->data<T>(), alpha, scale,
-                               dx_ptr);
-
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    size_t limit = static_cast<size_t>(out->numel());
-    platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
-    for_range(functor);
-  }
-};
-}  // namespace operators
-}  // namespace paddle
+template <typename T, typename Context>
+void SeluKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                float scale,
+                float alpha,
+                DenseTensor* out) {
+  auto out_ptr = dev_ctx.template Alloc<T>(out);
+  SeluFunctor<T> functor(x.data<T>(), alpha, scale, out_ptr);
+  size_t limit = static_cast<size_t>(x.numel());
+  paddle::platform::ForRange<Context> for_range(dev_ctx, limit);
+  for_range(functor);
+}
+}  // namespace phi
@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
                 bool keep_dim) {
  auto dense_out = phi::Empty<T, Context>(dev_ctx);
  MetaTensor meta_out(&dense_out);
-  ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
+  ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out);
  MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
  return dense_out;
}
......
@@ -136,7 +136,9 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  return shared_memory[threadIdx.x];
}
-// Swap data
+/**
+ * @brief Swap data
+ */
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  T t_value;
@@ -145,7 +147,9 @@ __device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  (*second_value) = t_value;
}
-// swap with monotonic_type
+/**
+ * @brief Swap data according to monotonic_type.
+ */
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
@@ -155,6 +159,9 @@ __device__ __forceinline__ void Comparator(T* first_value,
  }
}
+/**
+ * @brief Swap data and data index according to monotonic_type.
+ */
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
@@ -170,6 +177,18 @@ __device__ __forceinline__ void ComparatorWithIndex(T* first_value,
  }
}
+/**
+ * @brief get the last pow of 2
+ */
+__device__ inline int GetLastPow2(int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+  return std::max(1, n - (n >> 1));
+}
+
}  // namespace details
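The GetLastPow2 helper added above smears the highest set bit of n into every lower bit and then subtracts the lower half, leaving the largest power of two not exceeding n. A host-side copy of the same bit trick with a few checks (plain C++, no __device__, hypothetical name GetLastPow2Host):

#include <algorithm>
#include <cassert>

// Largest power of two <= n (n > 0); same bit trick as details::GetLastPow2.
int GetLastPow2Host(int n) {
  n |= (n >> 1);  // after these ORs every bit below the top set bit is 1
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));  // keep only the top bit
}

int main() {
  assert(GetLastPow2Host(1) == 1);
  assert(GetLastPow2Host(6) == 4);      // 6 -> 0b111 = 7, 7 - 3 = 4
  assert(GetLastPow2Host(1000) == 512);
  assert(GetLastPow2Host(1024) == 1024);
  return 0;
}

In the Sort kernels below, this value clamps the bitonic-sort upper bound to a power of two no larger than the block's element count.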
/**
@@ -453,6 +472,29 @@ __device__ __forceinline__ void Reduce(T* out,
  }
}
+/*
+ * @brief Fill register with a constant according to OpFunc
+ *
+ * @template paraments
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. Currently only
+ * GPU was supported.
+ * OpFunc: Compute functor which has an operator() as following
+ *         template <typename InT>
+ *         struct XxxFunctor {
+ *           HOSTDEVICE InT operator()()
+ *           const {
+ *             return a;
+ *           }
+ *         };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * compute: Compute function which was declared like OpFunc<InT>().
+ */
template <typename InT,
          typename OutT,
          int NX,
@@ -466,6 +508,33 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
  }
}
+/*
+ * @brief Get ReturnsCount random data fromm compute according to state, state
+ * can be curandStatePhilox4_32_10_t, hiprandStatePhilox4_32_10_t which has beed
+ * initialized.
+ *
+ * @template paraments
+ * StateType: the type of state, can be curandStatePhilox4_32_10_t or
+ * hiprandStatePhilox4_32_10_t.
+ * OutT: the type of out register.
+ * ReturnsCount: The number of random data generated by OpFunc.
+ * BlockSize: Identifies the current device thread index method. Currently only
+ * GPU was supported.
+ * OpFunc: Compute functor which has an operator() as following
+ *         template <typename T>
+ *         struct XxxFunctor {
+ *           HOSTDEVICE InT operator()(StateType state)
+ *           const {
+ *             return ranomd(state);  // Returns ReturnsCount random numbers with
+ *                                    // data type T
+ *           }
+ *         };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * compute: Compute function which was declared like OpFunc<T>().
+ */
template <typename StateType,
          typename OutT,
          int ReturnsCount,
...@@ -481,108 +550,171 @@ __device__ __forceinline__ void ElementwiseRandom(OutT* out, ...@@ -481,108 +550,171 @@ __device__ __forceinline__ void ElementwiseRandom(OutT* out,
} }
} }
// attention please set share_size = blockDim.x; /*
// data and b are the register pointer * @brief Complete the prefix and in the block, each thread calculates 2 data,
#define shared_size 64 * the size of out and in is 2, and BlockDim.x must be less then 512.
template <typename InT, *
typename OutT, * @template paraments
int NX, * InT: the type of input register.
int NY, * OutT: the type of out register.
int BlockSize, * BlockSize: Identifies the current device thread index method. Currently only
class OpFunc> * GPU was supported.
* OpFunc: Compute functor which has an operator() as following
* template <typename T>
* struct XxxFunctor {
* HOSTDEVICE InT operator()(T a, T b)
* const {
* return a + b;
* }
* };
*
* @param
* out: The register pointer of out, the size is 2;
* in: The register pointer of input, the size is 2;
* compute: Compute function which was declared like OpFunc<T>().
*/
#define SHARED_SIZE_LIMIT 512
template <typename InT, typename OutT, int BlockSize, class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out, __device__ __forceinline__ void Cumsum(OutT* out,
const InT* in, const InT* in,
OpFunc compute) { OpFunc compute) {
__shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32]; constexpr int kSize = SHARED_SIZE_LIMIT * 2 + (SHARED_SIZE_LIMIT * 2) / 32;
__shared__ InT temp[kSize];
int stride_size = blockDim.x;
int tidx = threadIdx.x; int tidx = threadIdx.x;
temp[tidx + tidx / 32] = in[0]; temp[tidx + tidx / 32] = in[0];
temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1]; temp[stride_size + tidx + (stride_size + tidx) / 32] = in[1];
for (int stride = 1; stride <= blockDim.x; stride *= 2) { for (int stride = 1; stride <= stride_size; stride *= 2) {
__syncthreads(); __syncthreads();
int index = (tidx + 1) * 2 * stride - 1; int index = (tidx + 1) * 2 * stride - 1;
if (index < (blockDim.x * 2)) { if (index < (blockDim.x * 2)) {
temp[index + index / 32] += temp[index - stride + (index - stride) / 32]; temp[index + index / 32] =
compute(temp[index + index / 2],
temp[index - stride + (index - stride) / 32]);
} }
} }
for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) { for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
__syncthreads(); __syncthreads();
int index = (tidx + 1) * 2 * stride - 1; int index = (tidx + 1) * 2 * stride - 1;
if ((index + stride) < (blockDim.x * 2)) { if ((index + stride) < (blockDim.x * 2)) {
temp[index + stride + (stride + index) / 32] += temp[index + stride + (stride + index) / 32] =
temp[index + (index) / 32]; compute(temp[index + stride + (stride + index) / 32],
temp[index + (index) / 32]);
} }
} }
__syncthreads(); __syncthreads();
out[0] = static_cast<OutT>(temp[tidx + tidx / 32]); out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
out[1] = out[1] =
static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]); static_cast<OutT>(temp[tidx + stride_size + (tidx + stride_size) / 32]);
} }
#undef SHARED_SIZE_LIMIT
#define SHARED_SIZE_LIMIT \
1024 // each thread load 2 data from global memory so SHARED_SIZE_LIMIT must /*
// larger than blockDim.x * 2 * @brief Sort data in this block, each thread calculates 2 data, the size of out
// if monotonic_type = 1 then increase * and in is 2, and BlockDim.x must be less then 512.
// if gridDim.x > 1 please set monotonic_type = blockIdx.x & 1; blockIdx.x % 2 *
// == 1 the increase * @template paraments
template <typename T> * InT: the type of input register.
__device__ __forceinline__ void Sort(T* dst, * OutT: the type of out register.
const T* src_data, * BlockSize: Identifies the current device thread index method. Currently only
* GPU was supported.
*
* @param
* out: The register pointer of out, the size is 2.
* in: The register pointer of input, the size is 2.
* num: The num of this block
* monotonic_type: if monotonic_type = 1 then sorted in ascending order, eles
* sorted in escending.
*/
#define SHARED_SIZE_LIMIT 1024
// each thread load 2 data from global memory so SHARED_SIZE_LIMIT must
// larger than blockDim.x * 2
template <typename InT, typename OutT, int BlockSize>
__device__ __forceinline__ void Sort(OutT* out,
const InT* in,
int num, int num,
int monotonic_type) { int monotonic_type) {
// todo: set num = Pow2(num) int upper_bound = blockDim.x;
// update upper_bound
upper_bound = std::min(details::GetLastPow2(num), upper_bound);
// shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2 // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
__shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than __shared__ InT value[SHARED_SIZE_LIMIT];
// blockDim * 2 int stride_size = blockDim.x;
// Copy value and index from src and src_index // shareMem's size must larger than blockDim * 2
value[threadIdx.x] = src_data[0]; // Copy value from in
value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1]; value[threadIdx.x] = in[0];
value[threadIdx.x + stride_size] = in[1];
// make bitonicSort // make bitonicSort
for (int size = 2; size < num; size <<= 1) { for (int size = 2; size < upper_bound; size <<= 1) {
int bitonic_type = (threadIdx.x & (size / 2)) != 0; int bitonic_type = (threadIdx.x & (size / 2)) != 0;
for (int stride = size / 2; stride > 0; stride >>= 1) { for (int stride = size / 2; stride > 0; stride >>= 1) {
__syncthreads(); __syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type); details::Comparator<InT>(&value[pos], &value[pos + stride], bitonic_type);
} }
} }
// last sort // last sort
for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) { for (int stride = stride_size; stride > 0; stride >>= 1) {
__syncthreads(); __syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
// last sort when monotonic_type = 1 then increase // last sort when monotonic_type = 1 then increase
details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type); details::Comparator<InT>(&value[pos], &value[pos + stride], monotonic_type);
} }
__syncthreads(); __syncthreads();
dst[0] = value[threadIdx.x]; out[0] = static_cast<OutT>(value[threadIdx.x]);
dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)]; out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
} }
/*
 * @brief Sort data with data_index in this block. Each thread handles 2 data;
 * the size of out and in is 2, and blockDim.x must be less than 512.
 *
 * @template parameters
 * InT: The type of the input register.
 * OutT: The type of the output register.
 * IndexType: The type of the index.
 * BlockSize: Identifies the current device thread index method. Currently only
 * GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * out_index: The register pointer of out_index, the size is 2.
 * in: The register pointer of input, the size is 2.
 * in_index: The register pointer of in_index, the size is 2.
 * num: The number of data elements handled by this block.
 * monotonic_type: if monotonic_type = 1 the data is sorted in ascending order,
 * otherwise in descending order.
 */
template <typename InT, typename OutT, typename IndexType, int BlockSize>
__device__ __forceinline__ void Sort(OutT* out,
                                     IndexType* out_index,
                                     const InT* in,
                                     IndexType* in_index,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // Shared memory for values and indices; num must be smaller than
  // SHARED_SIZE_LIMIT / 2 and the shared memory size must be larger than
  // blockDim.x * 2.
  __shared__ InT value[SHARED_SIZE_LIMIT];
  __shared__ IndexType index[SHARED_SIZE_LIMIT];
  // Copy value and index from in and in_index
  int stride_size = blockDim.x;
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // index
  index[threadIdx.x] = in_index[0];
  index[threadIdx.x + stride_size] = in_index[1];
  // make bitonicSort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                   &value[pos + stride],
                                                   &index[pos],
                                                   &index[pos + stride],
                                                   bitonic_type);
    }
  }
  // last merge step: monotonic_type = 1 means ascending order
  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 monotonic_type);
  }
  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
  out_index[0] = index[threadIdx.x];
  out_index[1] = index[threadIdx.x + stride_size];
}
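The indexed variant above additionally carries each value's original position through the same sorting network, which is what top-k style kernels need. A tiny host-side sketch (plain Python, not Paddle code) of the value/index pairing it produces:

def sort_with_index(values, ascending=True):
    # Return sorted values plus the original position of every element,
    # mirroring the out / out_index pair written by the kernel above.
    order = sorted(range(len(values)), key=lambda i: values[i],
                   reverse=not ascending)
    return [values[i] for i in order], order


vals, idx = sort_with_index([0.3, 0.9, 0.1])
# vals == [0.1, 0.3, 0.9], idx == [2, 0, 1]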
template <typename T1, typename T2, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorTernary(
OutT* out, const T1* in1, const T2* in2, OpFunc func, int num) {
func(out, in1, in2, num);
}
template <typename InT, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorBinary(OutT* out,
const InT* in,
OpFunc func,
int num) {
func(out, in, num);
}
} // namespace kps
......
...@@ -348,6 +348,29 @@ __device__ __forceinline__ void Reduce(T* out,
  }
}
/*
 * @brief Fill register with a constant according to OpFunc.
 *
 * @template parameters
 * InT: The data type used by the compute functor.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * OpFunc: Compute functor which has an operator() as following:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()() const { return a; }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
......
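For orientation, the functor-based Init primitive documented above simply fills a per-thread register buffer by invoking a no-argument functor NX * NY times. A hypothetical host-side sketch in Python (names invented for illustration):

class FillTwoFunctor:
    # Plays the role of OpFunc<InT>: no arguments, returns the fill value.
    def __call__(self):
        return 2.0


def init_with_functor(nx, ny, compute):
    # Fill an "NX * NY register" with the functor's constant.
    return [compute() for _ in range(nx * ny)]


print(init_with_functor(4, 1, FillTwoFunctor()))  # [2.0, 2.0, 2.0, 2.0]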
...@@ -297,6 +297,24 @@ __device__ __forceinline__ void ReadData(T* dst,
/**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Each thread needs to load NY rows; only NY = 1 is supported.
 * ArgsT: The type of dst; ArgsT can be std::tuple<T> or std::tuple<Args>.
 * Index: The index of data stored in dst.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
 */
template <typename T,
          int NX,
...@@ -714,6 +732,20 @@ __device__ __forceinline__ void ReadDataBc(
  }
}
/**
 * @brief Initialize register with the data index.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * NY: Number of data to initialize; only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: The register pointer of init data, the size is NX.
 */
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
  int thread_offset = block_offset + threadIdx.x * NX;
......
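A short host-side sketch (plain Python, illustrative only) of what InitWithDataIndex computes per thread: each register slot receives its own global data index, block_offset + thread_id * NX + i.

def init_with_data_index(block_offset, thread_id, nx):
    # Mirror of the kernel: thread_offset = block_offset + threadIdx.x * NX.
    thread_offset = block_offset + thread_id * nx
    return [thread_offset + i for i in range(nx)]


print(init_with_data_index(block_offset=128, thread_id=3, nx=4))  # [140, 141, 142, 143]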
...@@ -244,6 +244,24 @@ __device__ __inline__ void ReadData(T* dst,
/**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Each thread needs to load NY rows; only NY = 1 is supported.
 * ArgsT: The type of dst; ArgsT can be std::tuple<T> or std::tuple<Args>.
 * Index: The index of data stored in dst.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
 */
template <typename T,
          int NX,
...@@ -646,5 +664,28 @@ __device__ __inline__ void ReadDataBc(
  }
}
/**
* @brief Initialize register with data index.
*
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * NY: Number of data to initialize; only NY = 1 is supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* init_data: The register pointer of init data, the size is NX.
*/
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
int thread_offset = block_offset + core_id() * NX;
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
dst[nx] = static_cast<T>(thread_offset + nx);
}
}
} // namespace kps
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SeluGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& d_out,
float scale,
float alpha,
DenseTensor* d_x);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SeluKernel(const Context& dev_ctx,
const DenseTensor& x,
float scale,
float alpha,
DenseTensor* out);
}  // namespace phi
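The two headers above declare the SELU forward and backward kernels. For context, a usage sketch through the public Python API (assuming paddle.nn.functional.selu, which exposes the same scale and alpha parameters; not part of this diff):

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[-1.0, 0.0], [1.0, 2.0]])
# selu(x) = scale * x                    , x > 0
#         = scale * alpha * (exp(x) - 1) , x <= 0
y = F.selu(x, scale=1.0507, alpha=1.6733)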
...@@ -17,29 +17,37 @@ limitations under the License. */
namespace phi {

KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
    // When ctx is InferShapeArgumentMappingContext, reduce_all is used in
    // InferShape, so we must return the "sum_raw" KernelSignature.
    // The InferMeta function (i.e. ReduceInferMetaBase) is in accordance with
    // the "sum_raw" KernelSignature.
    if (ctx.IsForInferShape() || reduce_all) {
      return KernelSignature("sum_raw",
                             {"X"},
                             {"dim", "keep_dim", "reduce_all", "out_dtype"},
                             {"Out"});
    }
    return KernelSignature(
        "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
  }
  return KernelSignature("unregistered", {}, {}, {});
}

KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
    // When ctx is InferShapeArgumentMappingContext, reduce_all is used in
    // InferShape, so we must return the "mean_raw" KernelSignature.
    // The InferMeta function (i.e. MeanRawInferMeta) is in accordance with the
    // "mean_raw" KernelSignature.
    if (ctx.IsForInferShape() || reduce_all) {
      return KernelSignature(
          "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
    }
    return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
  }
  return KernelSignature("unregistered", {}, {}, {});
}
......
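For reference, the attributes forwarded by these mappings ("dim", "keep_dim", "reduce_all", "out_dtype") correspond to the axis / keepdim / dtype arguments of the public reduction APIs. A small usage sketch (illustrative, not part of this diff):

import paddle

x = paddle.rand([2, 3, 4])
s = paddle.sum(x, axis=[1, 2], keepdim=True, dtype='float64')  # maps to "sum"
m = paddle.mean(x)  # reduces over all axes, i.e. the reduce_all / "mean_raw" case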
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SeluGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("selu_grad",
{"Out", GradVarName("Out")},
{"scale", "alpha"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(selu_grad, phi::SeluGradGradOpArgumentMapping);
...@@ -80,6 +80,8 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
    return selected_rows_outputs.count(name) > 0;
  }

  bool IsForInferShape() const override { return false; }

 private:
  const std::unordered_set<std::string> dense_tensor_inputs;
  const std::unordered_set<std::string> selected_rows_inputs;
......
...@@ -1266,7 +1266,7 @@ function card_test() {
    elif [ "${WITH_ASCEND_CL}" == "ON" ];then
        CUDA_DEVICE_COUNT=1
    elif [ "${WITH_ROCM}" == "ON" ];then
        CUDA_DEVICE_COUNT=$(rocm-smi -i | grep GPU | wc -l)
    else
        CUDA_DEVICE_COUNT=$(nvidia-smi -L | wc -l)
    fi
......
...@@ -25,6 +25,12 @@ function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_pa
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction()

function(download_data install_dir url data_file check_sum)
    if (NOT EXISTS ${install_dir}/${data_file})
        inference_download_and_uncompress(${install_dir} ${url} ${data_file} ${check_sum})
    endif()
endfunction()

function(download_quant_data install_dir data_file check_sum)
    if (NOT EXISTS ${install_dir}/${data_file})
        inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum})
...@@ -290,8 +296,9 @@ if(LINUX AND WITH_MKLDNN)
    ### PTQ INT8

    # PTQ int8 lstm model
    set(LSTM_DATA_FILE "quant_lstm_input_data.tar.gz")
    set(LSTM_URL "${INFERENCE_URL}/int8/unittest_model_data")
    download_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_URL} ${LSTM_DATA_FILE} add84c754e9b792fea1fbd728d134ab7)
    set(QUANT2_FP32_LSTM_MODEL_ARCHIVE "lstm_fp32_model.tar.gz")
    download_lstm_model(${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_FP32_LSTM_MODEL_ARCHIVE} eecd9f44d69a84acc1cf2235c4b8b743)
    inference_quant2_int8_lstm_model_test(test_quant2_int8_lstm_mkldnn ${QUANT2_INT8_LSTM_SAVE_PATH}/lstm_fp32_model ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH}/quant_lstm_input_data)
......
if(WITH_IPU)
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
endif()
...@@ -423,6 +423,14 @@ class TestMoveAxis(unittest.TestCase):
        self.assertEqual(np.array_equal(out.numpy(), expected), True)
        paddle.enable_static()

    def test_moveaxis3(self):
        paddle.disable_static()
        x = paddle.to_tensor(
            [[1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j]])
        out = x.moveaxis(0, 1)
        self.assertEqual(out.shape, [2, 3])
        paddle.enable_static()

    def test_error(self):
        x = paddle.randn([2, 3, 4, 5])
        # src must have the same number with dst
......
...@@ -51,6 +51,10 @@ class TestVarBase(unittest.TestCase):
                    np.array_equal(x.numpy(), np.array([1.2], 'float16')))
            self.assertEqual(x.dtype, core.VarDesc.VarType.FP16)

            # set_default_dtype takes effect on int
            x = paddle.to_tensor(1, place=place)
            self.assertTrue(x.dtype, core.VarDesc.VarType.INT64)

            # set_default_dtype takes effect on float
            x = paddle.to_tensor(1.2, place=place, stop_gradient=False)
            self.assertTrue(
......
...@@ -110,12 +110,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
            "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
        )

    if not isinstance(data, np.ndarray):

        def _handle_dtype(data, dtype):
...@@ -139,7 +133,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
            data.stop_gradient = stop_gradient
            return data
        elif isinstance(data, (core.LoDTensor, core.Tensor)):
            # shouldn't expose it to users, just for internal use.
            # convert core.Tensor/core.LoDTensor to VarBase first
            # Currently, there is no copy when places are the same
            data = paddle.Tensor(data)
...@@ -152,7 +146,8 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
            raise TypeError(
                "Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor".
                format(type(data)))
        if not dtype:
            if data.dtype in [
                    'float16', 'float32', 'float64', 'complex64', 'complex128'
            ]:
                default_type = paddle.get_default_dtype()
...@@ -161,6 +156,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
                        'float16', 'float32'
                    ] else 'complex128'
                data = data.astype(default_type)
            # Windows default type is 'int32', while Linux/Mac is 'int64'; unify them.
            if data.dtype in ['int32']:
                default_type = "int64"
                data = data.astype(default_type)

    if dtype and convert_dtype(dtype) != data.dtype:
        data = data.astype(convert_dtype(dtype))
......
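The int32-to-int64 unification above means integer inputs now produce the same dtype on every platform. A quick sketch of the observable behaviour (illustrative, matching the unit test added in this commit):

import paddle

t = paddle.to_tensor(1)
print(t.dtype)   # paddle.int64 on Linux, macOS and Windows alike
f = paddle.to_tensor(1.2)
print(f.dtype)   # follows paddle.get_default_dtype(), float32 by default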
...@@ -2737,9 +2737,10 @@ def moveaxis(x, source, destination, name=None):
        out, _ = _C_ops.transpose2(x, 'axis', perm)
        return out

    check_variable_and_dtype(x, 'x', [
        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
        'complex128'
    ], 'moveaxis')

    helper = LayerHelper('moveaxis', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
......
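With complex64/complex128 added to the accepted dtypes, moveaxis also covers complex tensors in static mode; an eager-mode sketch of the same call (illustrative):

import paddle

x = paddle.to_tensor([[1 + 1j, -1 - 1j],
                      [1 + 1j, -1 - 1j],
                      [1 + 1j, -1 - 1j]])
y = paddle.moveaxis(x, 0, 1)
print(y.shape)   # [2, 3]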
...@@ -124,7 +124,7 @@
  args : (Tensor x, int64_t[] axis={}, bool keep_dim=false)
  output : Tensor
  infer_meta :
    func : MeanInferMeta
  kernel :
    func : mean
......