提交 943dedec 编写于 作者: P phlrain

add sgd kernel; test=develop

上级 a4bccde0
...@@ -2048,7 +2048,11 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2048,7 +2048,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
// deal with optional here // deal with optional here
if ((it == ctx.inputs.end() || it->second.size() == 0) && if ((it == ctx.inputs.end() || it->second.size() == 0) &&
(input_defs[i].type_index == (input_defs[i].type_index ==
std::type_index(typeid(paddle::optional<const phi::DenseTensor&>)))) { std::type_index(
typeid(paddle::optional<const phi::DenseTensor&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<const phi::SelectedRows&>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr); pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1; auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
......
...@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(
paddle::optional<const SelectedRows&>))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == } else if (arg_type ==
std::type_index(typeid(const std::vector<DenseTensor>&))) { std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(default_key.backend(), args_def->AppendInput(default_key.backend(),
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/sgd_kernel.h" #include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_helper.h" #include "paddle/phi/backends/gpu/gpu_helper.h"
...@@ -72,7 +73,6 @@ void SGDDenseKernel(const Context& dev_ctx, ...@@ -72,7 +73,6 @@ void SGDDenseKernel(const Context& dev_ctx,
bool multi_precision, bool multi_precision,
DenseTensor* param_out, DenseTensor* param_out,
DenseTensor* master_param_out) { DenseTensor* master_param_out) {
LOG(ERROR) << "run here";
using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type; using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
// do check here // do check here
// if (multi_precision) { // if (multi_precision) {
......
...@@ -17,9 +17,7 @@ ...@@ -17,9 +17,7 @@
namespace phi { namespace phi {
KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) {
LOG(ERROR) << "11";
if (ctx.IsDenseTensorInput("Grad")) { if (ctx.IsDenseTensorInput("Grad")) {
LOG(ERROR) << "dense";
return KernelSignature("sgd", return KernelSignature("sgd",
{"Param", "LearningRate", "Grad", "MasterParam"}, {"Param", "LearningRate", "Grad", "MasterParam"},
{"multi_precision"}, {"multi_precision"},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册