Unverified commit f3f2ffd4, authored by random-thoughtss, committed by GitHub

Merge branch 'AUTOMATIC1111:master' into master

@@ -44,7 +44,7 @@ body:
id: commit
attributes:
label: Commit where the problem happens
- description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+ description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
validations:
required: true
- type: dropdown
......
This diff is collapsed.
@@ -21,6 +21,7 @@
"Add layer normalization": "레이어 정규화(normalization) 추가",
"Add model hash to generation information": "생성 정보에 모델 해시 추가",
"Add model name to generation information": "생성 정보에 모델 이름 추가",
"Add number to filename when saving": "이미지를 저장할 때 파일명에 숫자 추가하기",
"Aesthetic imgs embedding": "스타일 이미지 임베딩", "Aesthetic imgs embedding": "스타일 이미지 임베딩",
"Aesthetic learning rate": "스타일 학습 수", "Aesthetic learning rate": "스타일 학습 수",
"Aesthetic steps": "스타일 스텝 수", "Aesthetic steps": "스타일 스텝 수",
...@@ -35,6 +36,7 @@ ...@@ -35,6 +36,7 @@
"Apply color correction to img2img results to match original colors.": "이미지→이미지 결과물이 기존 색상과 일치하도록 색상 보정 적용하기", "Apply color correction to img2img results to match original colors.": "이미지→이미지 결과물이 기존 색상과 일치하도록 색상 보정 적용하기",
"Apply selected styles to current prompt": "현재 프롬프트에 선택된 스타일 적용", "Apply selected styles to current prompt": "현재 프롬프트에 선택된 스타일 적용",
"Apply settings": "설정 적용하기", "Apply settings": "설정 적용하기",
"Auto focal point crop": "초점 기준 크롭(자동 감지)",
"Batch count": "배치 수", "Batch count": "배치 수",
"Batch from Directory": "저장 경로로부터 여러장 처리", "Batch from Directory": "저장 경로로부터 여러장 처리",
"Batch img2img": "이미지→이미지 배치", "Batch img2img": "이미지→이미지 배치",
...@@ -66,12 +68,14 @@ ...@@ -66,12 +68,14 @@
"Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows": "서로 다른 설정값으로 생성된 이미지의 그리드를 만듭니다. 아래의 설정으로 가로/세로에 어떤 설정값을 적용할지 선택하세요.", "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows": "서로 다른 설정값으로 생성된 이미지의 그리드를 만듭니다. 아래의 설정으로 가로/세로에 어떤 설정값을 적용할지 선택하세요.",
"Create a text file next to every image with generation parameters.": "생성된 이미지마다 생성 설정값을 담은 텍스트 파일 생성하기", "Create a text file next to every image with generation parameters.": "생성된 이미지마다 생성 설정값을 담은 텍스트 파일 생성하기",
"Create aesthetic images embedding": "스타일 이미지 임베딩 생성하기", "Create aesthetic images embedding": "스타일 이미지 임베딩 생성하기",
"Create debug image": "디버그 이미지 생성",
"Create embedding": "임베딩 생성", "Create embedding": "임베딩 생성",
"Create flipped copies": "좌우로 뒤집은 복사본 생성", "Create flipped copies": "좌우로 뒤집은 복사본 생성",
"Create hypernetwork": "하이퍼네트워크 생성", "Create hypernetwork": "하이퍼네트워크 생성",
"Create images embedding": "이미지 임베딩 생성하기", "Create images embedding": "이미지 임베딩 생성하기",
"Crop and resize": "잘라낸 후 리사이징", "Crop and resize": "잘라낸 후 리사이징",
"Crop to fit": "잘라내서 맞추기", "Crop to fit": "잘라내서 맞추기",
"custom fold": "커스텀 경로",
"Custom Name (Optional)": "병합 모델 이름 (선택사항)", "Custom Name (Optional)": "병합 모델 이름 (선택사항)",
"Dataset directory": "데이터셋 경로", "Dataset directory": "데이터셋 경로",
"DDIM": "DDIM", "DDIM": "DDIM",
...@@ -107,6 +111,7 @@ ...@@ -107,6 +111,7 @@
"Embedding": "임베딩", "Embedding": "임베딩",
"Embedding Learning rate": "임베딩 학습률", "Embedding Learning rate": "임베딩 학습률",
"Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention": "강조 : (텍스트)를 이용해 모델의 텍스트에 대한 가중치를 더 강하게 주고 [텍스트]를 이용해 더 약하게 줍니다.", "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention": "강조 : (텍스트)를 이용해 모델의 텍스트에 대한 가중치를 더 강하게 주고 [텍스트]를 이용해 더 약하게 줍니다.",
"Enable Autocomplete": "태그 자동완성 사용",
"Enable full page image viewer": "전체 페이지 이미지 뷰어 활성화", "Enable full page image viewer": "전체 페이지 이미지 뷰어 활성화",
"Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply.": "더 예리하고 깔끔한 결과물을 위해 K 샘플러들에 양자화를 적용합니다. 존재하는 시드가 변경될 수 있습니다. 재시작이 필요합니다.", "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply.": "더 예리하고 깔끔한 결과물을 위해 K 샘플러들에 양자화를 적용합니다. 존재하는 시드가 변경될 수 있습니다. 재시작이 필요합니다.",
"End Page": "마지막 페이지", "End Page": "마지막 페이지",
...@@ -145,6 +150,9 @@ ...@@ -145,6 +150,9 @@
"First Page": "처음 페이지", "First Page": "처음 페이지",
"Firstpass height": "초기 세로길이", "Firstpass height": "초기 세로길이",
"Firstpass width": "초기 가로길이", "Firstpass width": "초기 가로길이",
"Focal point edges weight": "경계면 가중치",
"Focal point entropy weight": "엔트로피 가중치",
"Focal point face weight": "얼굴 가중치",
"Font for image grids that have text": "텍스트가 존재하는 그리드 이미지의 폰트", "Font for image grids that have text": "텍스트가 존재하는 그리드 이미지의 폰트",
"for detailed explanation.": "를 참조하십시오.", "for detailed explanation.": "를 참조하십시오.",
"For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.": "SD 업스케일링에서 타일 간 몇 픽셀을 겹치게 할지 결정하는 설정값입니다. 타일들이 다시 한 이미지로 합쳐질 때, 눈에 띄는 이음매가 없도록 서로 겹치게 됩니다.", "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.": "SD 업스케일링에서 타일 간 몇 픽셀을 겹치게 할지 결정하는 설정값입니다. 타일들이 다시 한 이미지로 합쳐질 때, 눈에 띄는 이음매가 없도록 서로 겹치게 됩니다.",
...@@ -195,6 +203,7 @@ ...@@ -195,6 +203,7 @@
"Inpaint masked": "마스크만 처리", "Inpaint masked": "마스크만 처리",
"Inpaint not masked": "마스크 이외만 처리", "Inpaint not masked": "마스크 이외만 처리",
"Input directory": "인풋 이미지 경로", "Input directory": "인풋 이미지 경로",
"Input images directory": "이미지 경로 입력",
"Interpolation Method": "보간 방법", "Interpolation Method": "보간 방법",
"Interrogate\nCLIP": "CLIP\n분석", "Interrogate\nCLIP": "CLIP\n분석",
"Interrogate\nDeepBooru": "DeepBooru\n분석", "Interrogate\nDeepBooru": "DeepBooru\n분석",
...@@ -258,10 +267,12 @@ ...@@ -258,10 +267,12 @@
"None": "없음", "None": "없음",
"Nothing": "없음", "Nothing": "없음",
"Nothing found in the image.": "Nothing found in the image.", "Nothing found in the image.": "Nothing found in the image.",
"Number of columns on the page": "각 페이지마다 표시할 가로줄 수",
"Number of grids in each row": "각 세로줄마다 표시될 그리드 수", "Number of grids in each row": "각 세로줄마다 표시될 그리드 수",
"number of images to delete consecutively next": "연속적으로 삭제할 이미지 수", "number of images to delete consecutively next": "연속적으로 삭제할 이미지 수",
"Number of pictures displayed on each page": "각 페이지에 표시될 이미지 수", "Number of pictures displayed on each page": "각 페이지에 표시될 이미지 수",
"Number of repeats for a single input image per epoch; used only for displaying epoch number": "세대(Epoch)당 단일 인풋 이미지의 반복 횟수 - 세대(Epoch) 숫자를 표시하는 데에만 사용됩니다. ", "Number of repeats for a single input image per epoch; used only for displaying epoch number": "세대(Epoch)당 단일 인풋 이미지의 반복 횟수 - 세대(Epoch) 숫자를 표시하는 데에만 사용됩니다. ",
"Number of rows on the page": "각 페이지마다 표시할 세로줄 수",
"Number of vectors per token": "토큰별 벡터 수", "Number of vectors per token": "토큰별 벡터 수",
"Open for Clip Aesthetic!": "클립 스타일 기능을 활성화하려면 클릭!", "Open for Clip Aesthetic!": "클립 스타일 기능을 활성화하려면 클릭!",
"Open images output directory": "이미지 저장 경로 열기", "Open images output directory": "이미지 저장 경로 열기",
...@@ -375,6 +386,7 @@ ...@@ -375,6 +386,7 @@
"Seed": "시드", "Seed": "시드",
"Seed of a different picture to be mixed into the generation.": "결과물에 섞일 다른 그림의 시드", "Seed of a different picture to be mixed into the generation.": "결과물에 섞일 다른 그림의 시드",
"Select activation function of hypernetwork": "하이퍼네트워크 활성화 함수 선택", "Select activation function of hypernetwork": "하이퍼네트워크 활성화 함수 선택",
"Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended": "레이어 가중치 초기화 방식 선택 - relu류 : Kaiming 추천, sigmoid류 : Xavier 추천",
"Select which Real-ESRGAN models to show in the web UI. (Requires restart)": "WebUI에 표시할 Real-ESRGAN 모델을 선택하십시오. (재시작 필요)", "Select which Real-ESRGAN models to show in the web UI. (Requires restart)": "WebUI에 표시할 Real-ESRGAN 모델을 선택하십시오. (재시작 필요)",
"Send to extras": "부가기능으로 전송", "Send to extras": "부가기능으로 전송",
"Send to img2img": "이미지→이미지로 전송", "Send to img2img": "이미지→이미지로 전송",
...@@ -465,10 +477,11 @@ ...@@ -465,10 +477,11 @@
"Use BLIP for caption": "캡션에 BLIP 사용", "Use BLIP for caption": "캡션에 BLIP 사용",
"Use deepbooru for caption": "캡션에 deepbooru 사용", "Use deepbooru for caption": "캡션에 deepbooru 사용",
"Use dropout": "드롭아웃 사용", "Use dropout": "드롭아웃 사용",
"Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.": "다음 태그들을 사용해 이미지 파일명 형식을 결정하세요 : [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]. 비워두면 기본값으로 설정됩니다.", "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.": "다음 태그들을 사용해 이미지 파일명 형식을 결정하세요 : [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]. 비워두면 기본값으로 설정됩니다.",
"Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.": "다음 태그들을 사용해 이미지와 그리드의 하위 디렉토리명의 형식을 결정하세요 : [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]. 비워두면 기본값으로 설정됩니다.", "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.": "다음 태그들을 사용해 이미지와 그리드의 하위 디렉토리명의 형식을 결정하세요 : [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]. 비워두면 기본값으로 설정됩니다.",
"Use old emphasis implementation. Can be useful to reproduce old seeds.": "옛 방식의 강조 구현을 사용합니다. 옛 시드를 재현하는 데 효과적일 수 있습니다.", "Use old emphasis implementation. Can be useful to reproduce old seeds.": "옛 방식의 강조 구현을 사용합니다. 옛 시드를 재현하는 데 효과적일 수 있습니다.",
"Use original name for output filename during batch process in extras tab": "부가기능 탭에서 이미지를 여러장 처리 시 결과물 파일명에 기존 파일명 사용하기", "Use original name for output filename during batch process in extras tab": "부가기능 탭에서 이미지를 여러장 처리 시 결과물 파일명에 기존 파일명 사용하기",
"Use same seed for each image": "각 이미지에 동일한 시드 사용",
"use spaces for tags in deepbooru": "deepbooru에서 태그에 공백 사용", "use spaces for tags in deepbooru": "deepbooru에서 태그에 공백 사용",
"User interface": "사용자 인터페이스", "User interface": "사용자 인터페이스",
"Var. seed": "바리에이션 시드", "Var. seed": "바리에이션 시드",
...@@ -485,6 +498,7 @@ ...@@ -485,6 +498,7 @@
"Which algorithm to use to produce the image": "이미지를 생성할 때 사용할 알고리즘", "Which algorithm to use to produce the image": "이미지를 생성할 때 사용할 알고리즘",
"Width": "가로", "Width": "가로",
"wiki": " 위키", "wiki": " 위키",
"Wildcards": "와일드카드",
"Will upscale the image to twice the dimensions; use width and height sliders to set tile size": "이미지를 설정된 사이즈의 2배로 업스케일합니다. 상단의 가로와 세로 슬라이더를 이용해 타일 사이즈를 지정하세요.", "Will upscale the image to twice the dimensions; use width and height sliders to set tile size": "이미지를 설정된 사이즈의 2배로 업스케일합니다. 상단의 가로와 세로 슬라이더를 이용해 타일 사이즈를 지정하세요.",
"With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising).": "이미지→이미지 진행 시, 슬라이더로 설정한 스텝 수를 정확히 실행하기 (일반적으로 디노이즈 강도가 낮을수록 실제 설정된 스텝 수보다 적게 진행됨)", "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising).": "이미지→이미지 진행 시, 슬라이더로 설정한 스텝 수를 정확히 실행하기 (일반적으로 디노이즈 강도가 낮을수록 실제 설정된 스텝 수보다 적게 진행됨)",
"Write image to a directory (default - log/images) and generation parameters into csv file.": "이미지를 경로에 저장하고, 설정값들을 csv 파일로 저장합니다. (기본 경로 - log/images)", "Write image to a directory (default - log/images) and generation parameters into csv file.": "이미지를 경로에 저장하고, 설정값들을 csv 파일로 저장합니다. (기본 경로 - log/images)",
......
This diff is collapsed.
This diff is collapsed.
@@ -7,6 +7,7 @@ import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
@@ -15,12 +16,12 @@ from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ImageToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
@@ -65,7 +66,7 @@ class Api:
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
+ return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
@@ -111,7 +112,11 @@ class Api:
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
- return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
+ if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extrasapi(self):
raise NotImplementedError
......
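Note on the two API changes above: importing `List` from `typing` keeps the response models importable on Python 3.8, where `list[str]` in a class body raises `TypeError: 'type' object is not subscriptable`; and `processed.js()` already returns a JSON string, so the `info: Json` field now parses into structured generation info rather than a JSON-encoded plain string. A minimal sketch of the compatibility point (model name hypothetical, not from this commit):

    from typing import List
    from pydantic import BaseModel, Field

    class ExampleResponse(BaseModel):
        # Works on Python 3.7+; the annotation `list[str]` needs Python 3.9+.
        images: List[str] = Field(default=None, title="Image")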
@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str
field_type: Any
field_value: Any
field_exclude: bool = False
class PydanticModelGenerator:
@@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
- field_value=fields["default"]))
+ field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self):
"""
@@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization
"""
fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
\ No newline at end of file
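The `exclude` flag relies on pydantic's per-field export exclusion (available in pydantic v1.9+): an excluded field such as `include_init_images` is still accepted as input but dropped from `.dict()`/`.json()` output, while the handler itself nulls out `init_images` and `mask` before echoing the request back. A standalone sketch (hypothetical model, not from this commit):

    from pydantic import BaseModel, Field

    class ExampleRequest(BaseModel):
        prompt: str = "a cat"
        include_init_images: bool = Field(default=False, exclude=True)

    ExampleRequest(include_init_images=True).dict()  # -> {'prompt': 'a cat'}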
@@ -5,6 +5,7 @@ import html
import os
import sys
import traceback
import inspect
import modules.textual_inversion.dataset
import torch
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
from collections import defaultdict, deque
from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
else:
for layer in self.linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
+ w, b = layer.weight.data, layer.bias.data
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
normal_(b, mean=0.0, std=0.005)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
elif weight_init == 'XavierNormal':
xavier_normal_(w)
zeros_(b)
elif weight_init == 'KaimingUniform':
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
elif weight_init == 'KaimingNormal':
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
else:
raise KeyError(f"Key {weight_init} is not defined as initialization!")
self.to(devices.device)
def fix_old_state_dict(self, state_dict):
@@ -105,7 +126,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -114,13 +135,14 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.activation_func = activation_func
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -144,6 +166,7 @@ class Hypernetwork:
state_dict['layer_structure'] = self.layer_structure
state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['weight_initialization'] = self.weight_init
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -158,15 +181,21 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
print(self.layer_structure)
self.activation_func = state_dict.get('activation_func', None)
print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -458,7 +487,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
......
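The new `weight_init` dispatch condenses to the helper below (an illustrative sketch; the Uniform variants mirror the Normal ones). Kaiming takes the activation into account because its gain is derived from the nonlinearity, while LayerNorm weights always get the small Normal init:

    import torch
    from torch.nn.init import normal_, xavier_normal_, kaiming_normal_, zeros_

    def init_linear(layer, weight_init="Normal", activation_func="relu"):
        w, b = layer.weight.data, layer.bias.data
        if weight_init == "Normal":
            normal_(w, mean=0.0, std=0.01)   # small default start, near-identity module output
            normal_(b, mean=0.0, std=0.005)
        elif weight_init == "XavierNormal":  # suggested for sigmoid/tanh-like activations
            xavier_normal_(w)
            zeros_(b)
        elif weight_init == "KaimingNormal": # suggested for relu/leakyrelu activations
            kaiming_normal_(w, nonlinearity='leaky_relu' if activation_func == 'leakyrelu' else 'relu')
            zeros_(b)
        else:
            raise KeyError(f"Key {weight_init} is not defined as initialization!")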
@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
- def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
......
@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
- re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
+ re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
@@ -343,7 +343,7 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 else self.default_time_format
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError as _:
@@ -362,9 +362,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
res += text
if pattern is None:
res += text
continue
pattern_args = []
@@ -385,12 +385,9 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is None:
+ if replacement is not None:
res += f'[{pattern}]'
else:
res += str(replacement)
continue
continue
res += f'[{pattern}]'
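The rewritten `re_pattern` splits a filename template into (literal text, optional bracketed tag) pairs, which is exactly what the simplified loop above consumes; unknown tags fall through to the final `res += f'[{pattern}]'`. A quick demonstration:

    import re
    re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")

    for m in re_pattern.finditer("img-[seed]-[datetime<%Y%m%d>]"):
        print(m.groups())
    # ('img-', 'seed')
    # ('-', 'datetime<%Y%m%d>')
    # ('', None)   <- final zero-width match at end of string; it adds nothing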
@@ -454,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
namegen = FilenameGenerator(p, seed, prompt)
if extension == 'png' and opts.enable_pnginfo and info is not None:
pnginfo = PngImagePlugin.PngInfo()
if existing_info is not None:
for k, v in existing_info.items():
pnginfo.add_text(k, str(v))
pnginfo.add_text(pnginfo_section_name, info)
else:
pnginfo = None
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
@@ -492,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if add_number:
basecount = get_next_sequence_number(path, basename)
fullfn = None
fullfn_without_extension = None
for i in range(500):
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
break
else:
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, file_decoration)
else:
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
fullfn_without_extension = os.path.join(path, forced_filename)
pnginfo = existing_info or {}
if info is not None:
pnginfo[pnginfo_section_name] = info
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
script_callbacks.before_image_saved_callback(params)
image = params.image
fullfn = params.filename
info = params.pnginfo.get(pnginfo_section_name, None)
fullfn_without_extension, extension = os.path.splitext(params.filename)
def exif_bytes():
return piexif.dump({
@@ -513,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
},
})
- if extension.lower() in ("jpg", "jpeg", "webp"):
+ if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
image.save(fullfn, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn)
else:
- image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
+ image.save(fullfn, quality=opts.jpeg_quality)
target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length
@@ -541,7 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
- script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
+ script_callbacks.image_saved_callback(params)
return fullfn, txt_fullfn
......
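In the reworked save path, pnginfo is built as a plain dict, handed to the before-save callback (which may mutate it), and only then serialized into PNG text chunks. A standalone sketch of that final PNG step (file name and infotext are examples):

    from PIL import Image, PngImagePlugin

    pnginfo_data = PngImagePlugin.PngInfo()
    # each pnginfo dict entry becomes a text chunk; the infotext lives under 'parameters'
    pnginfo_data.add_text("parameters", "a cat, Steps: 20, Sampler: Euler a")
    Image.new("RGB", (64, 64)).save("out.png", pnginfo=pnginfo_data)

    print(Image.open("out.png").info["parameters"])  # the text round-trips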
@@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
break
img = Image.open(image)
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -61,19 +63,25 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
is_batch = mode == 2
if is_inpaint:
# Drawn mask
if mask_mode == 0:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert('RGB')
# Uploaded mask
else:
image = init_img_inpaint
mask = init_mask_inpaint
# No mask
else:
image = init_img
mask = None
# Use the EXIF orientation of photos taken by smartphones.
image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
......
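Both new calls use Pillow's EXIF-aware transpose, so a portrait phone photo is physically rotated before img2img sees it instead of arriving sideways. A minimal sketch (file name assumed):

    from PIL import Image, ImageOps

    img = Image.open("photo.jpg")
    # rotates/flips the pixels per the EXIF Orientation tag and drops the tag
    img = ImageOps.exif_transpose(img)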
@@ -77,9 +77,8 @@ def get_correct_sampler(p):
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -109,13 +108,14 @@ class StableDiffusionProcessing():
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
- self.denoising_strength: float = 0
+ self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = opts.ddim_discretize
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
if not seed_enable_extras:
self.subseed = -1
@@ -129,7 +129,6 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -351,6 +350,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
for k, v in p.override_settings.items():
opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
opts.data[k] = v
return res
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list: if type(p.prompt) == list:
......
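`process_images` is now a thin wrapper that swaps the requested opts values in, runs the old loop (renamed `process_images_inner`), and restores the originals in the `finally` block; keys listed in `shared.restricted_opts` are filtered out in the constructor. A hypothetical caller-side sketch (the option key is just an example of a non-restricted opts entry):

    p = StableDiffusionProcessingTxt2Img(
        prompt="a cat",
        override_settings={"eta_noise_seed_delta": 31337},
    )
    processed = process_images(p)  # the opts value is restored afterwards, even on error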
@@ -9,15 +9,34 @@ def report_exception(c, job):
print(traceback.format_exc(), file=sys.stderr)
class ImageSaveParams:
def __init__(self, image, p, filename, pnginfo):
self.image = image
"""the PIL image itself"""
self.p = p
"""p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
self.filename = filename
"""name of file that the image would be saved to"""
self.pnginfo = pnginfo
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
@@ -49,10 +68,18 @@ def ui_settings_callback():
report_exception(c, 'ui_settings_callback')
- def image_saved_callback(image, p, fullfn, txt_fullfn):
+ def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
try:
- c.callback(image, p, fullfn, txt_fullfn)
+ c.callback(params)
except Exception:
report_exception(c, 'before_image_saved_callback')
def image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
try:
c.callback(params)
except Exception:
report_exception(c, 'image_saved_callback')
@@ -64,7 +91,6 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
@@ -90,11 +116,17 @@ def on_ui_settings(callback):
add_callback(callbacks_ui_settings, callback)
- def on_save_imaged(callback):
- """register a function to be called after modules.images.save_image is called.
- The callback is called with three arguments:
- - p - procesing object (or a dummy object with same fields if the image is saved using save button)
- - fullfn - image filename
- - txt_fullfn - text file with parameters; may be None
+ def on_before_image_saved(callback):
+ """register a function to be called before an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
+ """
+ add_callback(callbacks_before_image_saved, callback)
def on_image_saved(callback):
"""register a function to be called after an image is saved to a file.
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
""" """
add_callback(callbacks_image_saved, callback) add_callback(callbacks_image_saved, callback)
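A minimal extension-side sketch of the new hook (function name and pnginfo key are illustrative):

    from modules import script_callbacks

    def tag_image(params):
        # params is an ImageSaveParams; changes made here affect the saved file
        params.pnginfo["my_note"] = "tagged by example extension"

    script_callbacks.on_before_image_saved(tag_image)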
@@ -84,7 +84,7 @@ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load mod
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
cmd_opts = parser.parse_args()
- restricted_opts = [
+ restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
"outdir_samples",
@@ -94,7 +94,7 @@ restricted_opts = [
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
- ]
+ }
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
......
import cv2
import requests
import os
from collections import defaultdict
from math import log, sqrt
import numpy as np
from PIL import Image, ImageDraw
GREEN = "#0F0"
BLUE = "#00F"
RED = "#F00"
def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """
scale_by = 1
if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
focus = focal_point(im_debug, settings)
# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)
x1 = focus.x - x_half
if x1 < 0:
x1 = 0
elif x1 + settings.crop_width > im.width:
x1 = im.width - settings.crop_width
y1 = focus.y - y_half
if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height
x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height
crop = [x1, y1, x2, y2]
results = []
results.append(im.crop(tuple(crop)))
if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
pois = []
weight_pref_total = 0
if len(corner_points) > 0:
weight_pref_total += settings.corner_points_weight
if len(entropy_points) > 0:
weight_pref_total += settings.entropy_points_weight
if len(face_points) > 0:
weight_pref_total += settings.face_points_weight
corner_centroid = None
if len(corner_points) > 0:
corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)
entropy_centroid = None
if len(entropy_points) > 0:
entropy_centroid = centroid(entropy_points)
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
pois.append(entropy_centroid)
face_centroid = None
if len(face_points) > 0:
face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)
average_point = poi_average(pois, settings)
if settings.annotate_image:
d = ImageDraw.Draw(im)
max_size = min(im.width, im.height) * 0.07
if corner_centroid is not None:
color = BLUE
box = corner_centroid.bounding(max_size * corner_centroid.weight)
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
d.ellipse(box, outline=color)
if len(corner_points) > 1:
for f in corner_points:
d.rectangle(f.bounding(4), outline=color)
if entropy_centroid is not None:
color = "#ff0"
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
d.ellipse(box, outline=color)
if len(entropy_points) > 1:
for f in entropy_points:
d.rectangle(f.bounding(4), outline=color)
if face_centroid is not None:
color = RED
box = face_centroid.bounding(max_size * face_centroid.weight)
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
d.ellipse(box, outline=color)
if len(face_points) > 1:
for f in face_points:
d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point
def image_face_points(im, settings):
if settings.dnn_model_path is not None:
detector = cv2.FaceDetectorYN.create(
settings.dnn_model_path,
"",
(im.width, im.height),
0.9, # score threshold
0.3, # nms threshold
5000 # keep top k before nms
)
faces = detector.detect(np.array(im))
results = []
if faces[1] is not None:
for face in faces[1]:
x = face[0]
y = face[1]
w = face[2]
h = face[3]
results.append(
PointOfInterest(
int(x + (w * 0.5)), # face focus left/right is center
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
size = w,
weight = 1/len(faces[1])
)
)
return results
else:
np_im = np.array(im)
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
tries = [
[ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
]
for t in tries:
classifier = cv2.CascadeClassifier(t[0])
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
except:
continue
if len(faces) > 0:
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
def image_corner_points(im, settings):
grayscale = im.convert("L")
# naive attempt at preventing focal points from collecting at watermarks near the bottom
gd = ImageDraw.Draw(grayscale)
gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
np_im = np.array(grayscale)
points = cv2.goodFeaturesToTrack(
np_im,
maxCorners=100,
qualityLevel=0.04,
minDistance=min(grayscale.width, grayscale.height)*0.06,
useHarrisDetector=False,
)
if points is None:
return []
focal_points = []
for point in points:
x, y = point.ravel()
focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
return focal_points
def image_entropy_points(im, settings):
landscape = im.height < im.width
portrait = im.height > im.width
if landscape:
move_idx = [0, 2]
move_max = im.size[0]
elif portrait:
move_idx = [1, 3]
move_max = im.size[1]
else:
return []
e_max = 0
crop_current = [0, 0, settings.crop_width, settings.crop_height]
crop_best = crop_current
while crop_current[move_idx[1]] < move_max:
crop = im.crop(tuple(crop_current))
e = image_entropy(crop)
if (e > e_max):
e_max = e
crop_best = list(crop_current)
crop_current[move_idx[0]] += 4
crop_current[move_idx[1]] += 4
x_mid = int(crop_best[0] + settings.crop_width/2)
y_mid = int(crop_best[1] + settings.crop_height/2)
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
def image_entropy(im):
# entropy of the binarized crop: a histogram is built over the 1-bit pixels
# and -log2(p) summed over the occupied bins (note this is -sum(log2 p_i),
# not the textbook Shannon entropy -sum(p_i * log2 p_i))
# band = np.asarray(im.convert("L"))
band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
def centroid(pois):
x = [poi.x for poi in pois]
y = [poi.y for poi in pois]
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
def poi_average(pois, settings):
weight = 0.0
x = 0.0
y = 0.0
for poi in pois:
weight += poi.weight
x += poi.x * poi.weight
y += poi.y * poi.weight
avg_x = round(x / weight)
avg_y = round(y / weight)
return PointOfInterest(avg_x, avg_y)
def is_landscape(w, h):
return w > h
def is_portrait(w, h):
return h > w
def is_square(w, h):
return w == h
def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx'
if not os.path.exists(dirname):
os.makedirs(dirname)
cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url)
with open(cache_file, "wb") as f:
f.write(response.content)
if os.path.exists(cache_file):
return cache_file
return None
class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
self.size = size
def bounding(self, size):
return [
self.x - size//2,
self.y - size//2,
self.x + size//2,
self.y + size//2
]
class Settings:
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
self.dnn_model_path = dnn_model_path
\ No newline at end of file
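Hypothetical usage of the module above (path and weights are examples, not defaults from this commit):

    from PIL import Image
    from modules.textual_inversion import autocrop

    settings = autocrop.Settings(
        crop_width=512, crop_height=512,
        face_points_weight=0.9, entropy_points_weight=0.3, corner_points_weight=0.5,
        dnn_model_path=autocrop.download_and_cache_models("models/opencv"),
    )
    cropped = autocrop.crop_image(Image.open("photo.jpg"), settings)[0]  # first result is the crop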
@@ -7,12 +7,14 @@ import tqdm
import time
from modules import shared, images
from modules.paths import models_path
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
- def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+ def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
...@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio) preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally: finally:
...@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce ...@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width width = process_width
height = process_height height = process_height
src = os.path.abspath(process_src) src = os.path.abspath(process_src)
...@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
splitted = image.crop((0, y, to_w, y + to_h)) splitted = image.crop((0, y, to_w, y + to_h))
yield splitted yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)): for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0] subindex = [0]
filename = os.path.join(src, imagefile) filename = os.path.join(src, imagefile)
...@@ -137,11 +140,36 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -137,11 +140,36 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
ratio = (img.height * width) / (img.width * height) ratio = (img.height * width) / (img.width * height)
inverse_xy = True inverse_xy = True
process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold: if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy): for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption) save_pic(splitted, index, existing_caption=existing_caption)
else: process_default_resize = False
if process_focal_crop and img.height != img.width:
dnn_model_path = None
try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
face_points_weight = process_focal_crop_face_weight,
entropy_points_weight = process_focal_crop_entropy_weight,
corner_points_weight = process_focal_crop_edges_weight,
annotate_image = process_focal_crop_debug,
dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption)
process_default_resize = False
if process_default_resize:
img = images.resize_image(1, img, width, height) img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption) save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob() shared.state.nextjob()
\ No newline at end of file
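A sketch of driving the extended entry point directly, assuming the webui runtime is already initialised (preprocess relies on shared state such as the interrogator). Directory paths are placeholders, and "ignore" is assumed to be an accepted preprocess_txt_action value.

# Hypothetical direct call of the extended preprocess() signature above;
# source/destination directories are placeholders.
from modules.textual_inversion.preprocess import preprocess

preprocess(
    "train/raw",        # process_src (placeholder)
    "train/processed",  # process_dst (placeholder)
    512, 512,           # process_width, process_height
    "ignore",           # preprocess_txt_action (assumed accepted value)
    False,              # process_flip
    False,              # process_split
    False,              # process_caption
    process_focal_crop=True,
    process_focal_crop_face_weight=0.9,
    process_focal_crop_entropy_weight=0.3,
    process_focal_crop_edges_weight=0.5,
    process_focal_crop_debug=False,
)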
@@ -10,7 +10,7 @@ import csv

 from PIL import Image, PngImagePlugin

-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

+    with devices.autocast():
+        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active
+
     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
     embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
@@ -164,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     for i in range(num_vectors_per_token):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

+    # Remove illegal characters from name.
+    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
     if not overwrite_old:
         assert not os.path.exists(fn), f"file {fn} already exists"
@@ -244,6 +249,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    forced_filename = "<none>"
     embedding_yet_to_be_embedded = False

     ititial_step = embedding.step or 0
@@ -283,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")

         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
-            last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+            # Before saving, change name to match current checkpoint.
+            embedding.name = f'{embedding_name}-{embedding.step}'
+            last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
             embedding.save(last_saved_file)
             embedding_yet_to_be_embedded = True
@@ -293,8 +301,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             })

         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
-            last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+            forced_filename = f'{embedding_name}-{embedding.step}'
+            last_saved_image = os.path.join(images_dir, forced_filename)
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 do_not_save_grid=True,
@@ -350,8 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
                 captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                 embedding_yet_to_be_embedded = False

-            image.save(last_saved_image)
+            last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
             last_saved_image += f", prompt: {preview_text}"

         shared.state.job_no = embedding.step
@@ -371,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     embedding.sd_checkpoint = checkpoint.hash
     embedding.sd_checkpoint_name = checkpoint.model_name
     embedding.cached_checksum = None
+    # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
+    embedding.name = embedding_name
+    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
     embedding.save(filename)

     return embedding, filename
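The sanitization line added to create_embedding keeps only characters that are safe in a filename (alphanumerics plus ".", "_", "-" and space). A quick worked example:

# Worked example of the name-sanitisation expression added above.
name = "my/embedding: v2?"
safe = "".join(x for x in name if (x.isalnum() or x in "._- "))
print(safe)  # myembedding v2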
@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
         new_hypernetwork_name = gr.Textbox(label="Name")
         new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
         new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
-        new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+        new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
+        new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
         new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
         new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
         overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1260,6 +1261,7 @@ def create_ui(wrap_gradio_gpu_call):
         with gr.Row():
             process_flip = gr.Checkbox(label='Create flipped copies')
             process_split = gr.Checkbox(label='Split oversized images')
+            process_focal_crop = gr.Checkbox(label='Auto focal point crop')
             process_caption = gr.Checkbox(label='Use BLIP for caption')
             process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
@@ -1267,6 +1269,12 @@ def create_ui(wrap_gradio_gpu_call):
             process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
             process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)

+        with gr.Row(visible=False) as process_focal_crop_row:
+            process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
+            process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
+            process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+            process_focal_crop_debug = gr.Checkbox(label='Create debug image')
+
         with gr.Row():
             with gr.Column(scale=3):
                 gr.HTML(value="")
@@ -1280,6 +1288,12 @@ def create_ui(wrap_gradio_gpu_call):
             outputs=[process_split_extra_row],
         )

+        process_focal_crop.change(
+            fn=lambda show: gr_show(show),
+            inputs=[process_focal_crop],
+            outputs=[process_focal_crop_row],
+        )
+
         with gr.Tab(label="Train"):
             gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
             with gr.Row():
@@ -1342,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call):
                 overwrite_old_hypernetwork,
                 new_hypernetwork_layer_structure,
                 new_hypernetwork_activation_func,
+                new_hypernetwork_initialization_option,
                 new_hypernetwork_add_layer_norm,
                 new_hypernetwork_use_dropout
             ],
@@ -1367,6 +1382,11 @@ def create_ui(wrap_gradio_gpu_call):
                 process_caption_deepbooru,
                 process_split_threshold,
                 process_overlap_ratio,
+                process_focal_crop,
+                process_focal_crop_face_weight,
+                process_focal_crop_entropy_weight,
+                process_focal_crop_edges_weight,
+                process_focal_crop_debug,
             ],
             outputs=[
                 ti_output,
...
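The process_focal_crop.change wiring above follows the webui's usual show/hide pattern: the checkbox drives the visibility of a pre-built, initially hidden row. A self-contained sketch of the same pattern in plain Gradio, with a local stand-in for the repo's gr_show helper:

# Stand-alone sketch of the checkbox-toggles-row pattern used above.
# gr_show below is a local stand-in for the helper defined in modules/ui.py.
import gradio as gr

def gr_show(visible=True):
    return {"visible": visible, "__type__": "update"}

with gr.Blocks() as demo:
    enable = gr.Checkbox(label="Auto focal point crop")
    with gr.Row(visible=False) as focal_row:
        face_weight = gr.Slider(label="Focal point face weight", value=0.9, minimum=0.0, maximum=1.0, step=0.05)

    enable.change(fn=lambda show: gr_show(show), inputs=[enable], outputs=[focal_row])

demo.launch()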
 import copy
 import math
 import os
+import random
 import sys
 import traceback
 import shlex
@@ -81,32 +82,34 @@ def cmdargs(line):
     return res


+def load_prompt_file(file):
+    if (file is None):
+        lines = []
+    else:
+        lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
+
+    return None, "\n".join(lines), gr.update(lines=7)
+
+
 class Script(scripts.Script):
     def title(self):
         return "Prompts from file or textbox"

     def ui(self, is_img2img):
-        # This checkbox would look nicer as two tabs, but there are two problems:
-        # 1) There is a bug in Gradio 3.3 that prevents visibility from working on Tabs
-        # 2) Even with Gradio 3.3.1, returning a control (like Tabs) that can't be used as input
-        #    causes a AttributeError: 'Tabs' object has no attribute 'preprocess' assert,
-        #    due to the way Script assumes all controls returned can be used as inputs.
-        # Therefore, there's no good way to use grouping components right now,
-        # so we will use a checkbox! :)
-        checkbox_txt = gr.Checkbox(label="Show Textbox", value=False)
-        file = gr.File(label="File with inputs", type='bytes')
-        prompt_txt = gr.TextArea(label="Prompts")
-        checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt])
-        return [checkbox_txt, file, prompt_txt]
+        checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
+
+        prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
+        file = gr.File(label="Upload prompt inputs", type='bytes')
+
+        file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
+
+        # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
+        # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
+        # be unclear to the user that shift-enter is needed.
+        prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
+        return [checkbox_iterate, file, prompt_txt]

-    def on_show(self, checkbox_txt, file, prompt_txt):
-        return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
-
-    def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
-        if checkbox_txt:
-            lines = [x.strip() for x in prompt_txt.splitlines()]
-        else:
-            lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
+    def run(self, p, checkbox_iterate, file, prompt_txt: str):
+        lines = [x.strip() for x in prompt_txt.splitlines()]
         lines = [x for x in lines if len(x) > 0]

         p.do_not_save_grid = True
@@ -134,6 +137,9 @@ class Script(scripts.Script):
             jobs.append(args)

         print(f"Will process {len(lines)} lines in {job_count} jobs.")
+        if (checkbox_iterate and p.seed == -1):
+            p.seed = int(random.randrange(4294967294))
+
         state.job_count = job_count

         images = []
@@ -146,5 +152,9 @@ class Script(scripts.Script):
             proc = process_images(copy_p)
             images += proc.images

+            if (checkbox_iterate):
+                p.seed = p.seed + (p.batch_size * p.n_iter)
+
         return Processed(p, images, p.seed, "")
\ No newline at end of file
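Each prompt line consumes batch_size * n_iter seeds, so the new "Iterate seed every line" option advances the base seed by exactly that amount after every line; every generated image then gets a unique, reproducible seed. Illustrative arithmetic:

# Illustration of the seed advancement added above (values are arbitrary).
base_seed = 1000
batch_size, n_iter = 2, 3          # 6 images per prompt line

seed = base_seed
for line_no in range(3):           # three prompt lines
    print(f"line {line_no}: seeds {seed}..{seed + batch_size * n_iter - 1}")
    seed += batch_size * n_iter    # same step as p.seed + (p.batch_size * p.n_iter)
# line 0: seeds 1000..1005
# line 1: seeds 1006..1011
# line 2: seeds 1012..1017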