提交 943dedec 编写于 作者: P phlrain

add sgd kernel; test=develop

上级 a4bccde0
......@@ -2048,7 +2048,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
// deal with optional here
if ((it == ctx.inputs.end() || it->second.size() == 0) &&
(input_defs[i].type_index ==
std::type_index(typeid(paddle::optional<const phi::DenseTensor&>)))) {
std::type_index(
typeid(paddle::optional<const phi::DenseTensor&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<const phi::SelectedRows&>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
......
......@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
paddle::optional<const SelectedRows&>))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type ==
std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(default_key.backend(),
......
......@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
......@@ -72,7 +73,6 @@ void SGDDenseKernel(const Context& dev_ctx,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out) {
LOG(ERROR) << "run here";
using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
// do check here
// if (multi_precision) {
......
......@@ -17,9 +17,7 @@
namespace phi {
KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) {
LOG(ERROR) << "11";
if (ctx.IsDenseTensorInput("Grad")) {
LOG(ERROR) << "dense";
return KernelSignature("sgd",
{"Param", "LearningRate", "Grad", "MasterParam"},
{"multi_precision"},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册