Fermer

novembre 26, 2024

Votre propre service API de génération d’images avec FLUX, Python et Diffusers —

Votre propre service API de génération d’images avec FLUX, Python et Diffusers —


FLUX (par Laboratoires de la Forêt-Noire) a pris d’assaut le monde de la génération d’images IA au cours des derniers mois. Non seulement il a battu Stable Diffusion (l’ancien roi de l’open source) sur de nombreux benchmarks, mais il a également surpassé les modèles propriétaires comme Dall-E ou Midjourney dans certains paramètres.

Mais comment procéderiez-vous pour utiliser FLUX sur l’une de vos applications ? On pourrait penser à utiliser des hôtes sans serveur comme Reproduire et d’autres, mais ceux-ci peuvent devenir très coûteux très rapidement et peuvent ne pas offrir la flexibilité dont vous avez besoin. C’est là que la création de votre propre serveur FLUX s’avère utile.

Dans cet article, nous vous guiderons dans la création de votre propre serveur FLUX à l’aide de Python. Ce serveur vous permettra de générer des images basées sur des invites textuelles via une simple API. Que vous utilisiez ce serveur pour un usage personnel ou que vous le déployiez dans le cadre d’une application de production, ce guide vous aidera à démarrer.

Conditions préalables

Avant de plonger dans le code, assurons-nous que vous disposez des outils et des bibliothèques nécessaires :

  • Python : Vous aurez besoin de Python 3 installé sur votre machine, de préférence la version 3.10.
  • torch: Le framework d’apprentissage profond que nous utiliserons pour exécuter FLUX.
  • diffusers: Donne accès au modèle FLUX.
  • transformers: Dépendance requise des diffuseurs.
  • sentencepiece: Nécessaire pour exécuter le tokenizer FLUX
  • protobuf: Nécessaire pour exécuter FLUX
  • accelerate: Aide à charger le modèle FLUX plus efficacement dans certains cas.
  • fastapi: Framework pour créer un serveur Web pouvant accepter les requêtes de génération d’images.
  • uvicorn: Nécessaire pour exécuter le serveur fastapi.
  • psutil: Nous permet de vérifier la quantité de RAM qu’il y a sur notre machine.

Vous pouvez installer toutes les bibliothèques en exécutant la commande suivante : pip install torch diffusers transformers accelerate fastapi uvicorn psutil.

Remarque pour les utilisateurs de MacOS : si vous utilisez un Mac avec une puce M1 ou M2, vous devez configurer PyTorch avec Metal pour des performances optimales. Suivez le guide officiel de PyTorch avec Metal avant de continuer.

Étape 1 : Configuration de l’environnement

Commençons le script en choisissant le bon périphérique pour exécuter l’inférence en fonction du matériel que nous utilisons.

import torch

device = 'cuda' # can also be 'cpu' or 'mps'

if device == 'mps' and not torch.backends.mps.is_available():
      raise Exception("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
      raise Exception("Device set to CUDA, but CUDA is not available")

Vous pouvez préciser cpu, cuda (pour les GPU NVIDIA), ou mps (pour les Metal Performance Shaders d’Apple). Le script vérifie ensuite si le périphérique sélectionné est disponible et déclenche une exception si ce n’est pas le cas.

Étape 2 : Chargement du modèle FLUX

Ensuite, nous chargeons le modèle FLUX. Nous allons charger le modèle avec une précision fp16, ce qui nous fera économiser de la mémoire sans trop de perte de qualité.

Remarque : À ce stade, il vous sera peut-être demandé de vous authentifier auprès de HuggingFace, car le modèle FLUX est sécurisé. Pour vous authentifier avec succès, vous devrez créer un compte HuggingFace, accéder à la page du modèle, accepter les conditions, puis créer un jeton HuggingFace à partir des paramètres de votre compte et l’ajouter sur votre ordinateur en tant que variable d’environnement HF_TOKEN.

from diffusers import DDIMScheduler, FluxPipeline
import psutil

model_name = "black-forest-labs/FLUX.1-dev"

print(f"Loading {model_name} on {device}")

pipeline = FluxPipeline.from_pretrained(
      model_name,

      
      
      torch_dtype=torch.float16,

      
      use_safetensors=True 
).to(device)


pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

Ici, nous chargeons le modèle FLUX à l’aide de la bibliothèque des diffuseurs. Le modèle que nous utilisons est black-forest-labs/FLUX.1-devchargé en précision fp16. Il existe également un modèle FLUX pro qui est plus puissant, mais malheureusement pas open source et ne peut donc pas être utilisé.

Nous utiliserons ici le planificateur DDIM, mais vous pouvez également en choisir un autre comme Euler ou UniPC. Vous pouvez en savoir plus sur les planificateurs ici.

Étant donné que la génération d’images peut nécessiter beaucoup de ressources, il est crucial d’optimiser l’utilisation de la mémoire, en particulier lorsqu’elle est exécutée sur un processeur ou un périphérique doté d’une mémoire limitée.

# Recommended if running on MPS or CPU with < 64 GB of RAM
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
      print("Enabling attention slicing")
      pipeline.enable_attention_slicing()

Ce code vérifie la mémoire totale disponible et active le découpage d’attention si le système dispose de moins de 64 Go de RAM. Le découpage de l’attention réduit l’utilisation de la mémoire lors de la génération d’images, ce qui est essentiel pour les appareils aux ressources limitées.

Étape 3 : Création de l’API avec FastAPI

Ensuite, nous allons configurer le serveur FastAPI, qui fournira une API pour générer des images.

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64

app = FastAPI()



app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)

FastAPI est un framework populaire pour créer des API Web avec Python. Dans ce cas, nous l’utilisons pour créer un serveur capable d’accepter les demandes de génération d’images. Nous utilisons également le middleware GZip pour compresser la réponse, ce qui est particulièrement utile lors du renvoi d’images au format base64.

Remarque : Dans un environnement de production, vous souhaiterez peut-être stocker les images générées dans un compartiment S3 ou un autre stockage cloud et renvoyer les URL au lieu des chaînes codées en base64, pour profiter d’un CDN et d’autres optimisations.

Étape 4 : Définition du modèle de demande

Nous devons définir un modèle pour les requêtes que notre API acceptera.

class GenerateRequest(BaseModel):
      prompt: str
      negative_prompt: str
      seed: conint(ge=0) = Field(..., description="Seed for random number generation")
      height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
      width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
      cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive integer or 0")
      steps: conint(ge=0) = Field(..., description="Number of steps")
      batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")

Ce modèle GenerateRequest définit les paramètres requis pour générer une image. L’invite est la description textuelle de l’image que vous souhaitez créer. Le négatif_prompt peut être utilisé pour spécifier ce que vous ne voulez pas dans l’image. D’autres champs incluent les dimensions de l’image, le nombre d’étapes d’inférence et la taille du lot.

Étape 5 : Création du point de terminaison de génération d’images

Créons maintenant le point de terminaison qui gérera les demandes de génération d’images.

@app.post("https://www.sitepoint.com/")
async def generate_image(request: GenerateRequest):
      
      
      if request.height % 8 != 0 or request.width % 8 != 0:
            raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")

      
      
      generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size)]

      images = pipeline(
            height=request.height,
            width=request.width,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            generator=generator,
            num_inference_steps=request.steps,
            guidance_scale=request.cfg,
            num_images_per_prompt=request.batch_size
      ).images

      
      
      
      base64_images = []
      for image in images:
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_images.append(img_str)

      return {
            "images": base64_images,
      }

Ce point de terminaison gère le processus de génération d’images. Il valide d’abord que la hauteur et la largeur sont des multiples de 8, comme l’exige FLUX. Il génère ensuite des images basées sur l’invite fournie et les renvoie sous forme de chaînes codées en base64.

Étape 6 : Démarrage du serveur

Enfin, ajoutons du code pour démarrer le serveur lorsque le script est exécuté.

@app.on_event("startup")
async def startup_event():
      print("Image generation server running")

if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=8000)

Ce code démarre le serveur FastAPI sur le port 8000, le rendant accessible depuis http://localhost:8000.

Étape 7 : tester votre serveur localement

Maintenant que votre serveur FLUX est opérationnel, il est temps de le tester. Vous pouvez utiliser curl, un outil de ligne de commande permettant d’effectuer des requêtes HTTP, pour interagir avec votre serveur :

curl -X POST "http://localhost:8000/" \
-H "Content-Type: application/json" \
-d '{
  "prompt": "A futuristic cityscape at sunset",
  "negative_prompt": "low quality, blurry",
  "seed": 42,
  "height": 512,
  "width": 512,
  "cfg": 7.5,
  "steps": 50,
  "batch_size": 1
}'

Conclusion

Félicitations! Vous avez créé avec succès votre propre serveur FLUX en utilisant Python. Cette configuration vous permet de générer des images basées sur des invites textuelles via une simple API. Si vous n’êtes pas satisfait des résultats du modèle de base FLUX, vous pourriez envisager d’affiner le modèle pour des performances encore meilleures ou des cas d’utilisation spécifiques.

Code complet

Vous pouvez trouver le code complet utilisé dans ce guide ci-dessous :

import torch

device = 'cuda' 

if device == 'mps' and not torch.backends.mps.is_available():
      raise Exception("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
      raise Exception("Device set to CUDA, but CUDA is not available")


from diffusers import DDIMScheduler, FluxPipeline
import psutil

model_name = "black-forest-labs/FLUX.1-dev"

print(f"Loading {model_name} on {device}")

pipeline = FluxPipeline.from_pretrained(
      model_name,

      
      
      torch_dtype=torch.float16,

      
      use_safetensors=True 
).to(device)


pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)


total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
      print("Enabling attention slicing")
      pipeline.enable_attention_slicing()


from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64

app = FastAPI()



app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)

class GenerateRequest(BaseModel):
      prompt: str
      negative_prompt: str
      seed: conint(ge=0) = Field(..., description="Seed for random number generation")
      height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
      width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
      cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive integer or 0")
      steps: conint(ge=0) = Field(..., description="Number of steps")
      batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")

@app.post("https://www.sitepoint.com/")
async def generate_image(request: GenerateRequest):
      
      
      if request.height % 8 != 0 or request.width % 8 != 0:
            raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")

      
      
      generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size)]

      images = pipeline(
            height=request.height,
            width=request.width,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            generator=generator,
            num_inference_steps=request.steps,
            guidance_scale=request.cfg,
            num_images_per_prompt=request.batch_size
      ).images

      
      
      
      base64_images = []
      for image in images:
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_images.append(img_str)

      return {
            "images": base64_images,
      }

@app.on_event("startup")
async def startup_event():
      print("Image generation server running")

if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=8000)




Source link