Question
PyTorch Model Performs Poorly in Flask Server Compared to Local Execution
I have a PyTorch model that generates high-quality images quickly when I run it locally. However, when I deploy the same code behind a Flask server, the generated images come out at much lower quality and generation is extremely slow.
Details:
Local Environment:
OS: Windows 10
Python Version: 3.12
PyTorch Version: (2.3.1+cu121)
GPU:
Server Environment:
Deployment: Flask
Server: waitress
Hosting: local
Resources: GPU
Code Snippet (standalone):
from diffusers import StableDiffusionPipeline
from torch import float16

# Load the pipeline once, in half precision
pipeline = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=float16)
# Disable the safety checker so no images get blacked out
pipeline.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
pipeline.to('cuda')
pipeline.enable_attention_slicing()

pipeline(
    'best quality, high quality, photorealistic, an astronaut riding a white horse in space',
    num_inference_steps=20,
    negative_prompt='bad quality, low quality',
    num_images_per_prompt=1,
    height=800,
    width=1000,
).images[0].save('1.png')
Result (less than 1 minute):
Code Snippet (flask):
from flask import Flask, request
from diffusers import StableDiffusionPipeline
from torch import float16

app = Flask(__name__)

# Load the pipeline once at import time, not per request
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=float16)  # stabilityai/sdxl-turbo
pipeline.to("cuda")
pipeline.enable_attention_slicing()

count = 0  # running index used for saved image filenames
# 'info' is my in-memory user store, defined elsewhere (omitted for brevity)

@app.route('/texttoimage/generate', methods=['POST'])
def ttig():
    global count
    # Session check: the 'season' cookie must match one of the user's active sessions
    if eval(request.cookies['season']) in (user := info[request.cookies['name'].encode()])[1]:
        user = user[4]
        # Per-request NSFW toggle: flag nothing when NSFW=true, everything otherwise
        pipeline.safety_checker = lambda images, **kwargs: (images, [False if request.args['NSFW'] == 'true' else True] * len(images))
        images = pipeline(
            request.args['prompt'],
            negative_prompt=request.args.get('negative'),
            num_inference_steps=int(request.args['quality']),
            num_images_per_prompt=int(request.args['count']),
            height=int(request.args['height']) // 8 * 8,  # round down to a multiple of 8
            width=int(request.args['width']) // 8 * 8,
        ).images
        for k, j in enumerate(images):
            user[f"{count + k}.{request.args['type']}"] = 'a'
            j.save(f"s/{count + k}.{request.args['type']}")
        count += len(images)
        return str(count)

if __name__ == '__main__':
    from waitress import serve
    serve(app, port=80)
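For context, the endpoint is called roughly like this (the cookie values below are placeholders; the real ones come from my login flow, which I've omitted):

import requests

resp = requests.post(
    'http://localhost/texttoimage/generate',
    params={
        'prompt': 'an astronaut riding a white horse in space',
        'negative': 'bad quality, low quality',
        'quality': '20',   # num_inference_steps
        'count': '1',      # num_images_per_prompt
        'height': '800',
        'width': '1000',
        'NSFW': 'false',
        'type': 'png',
    },
    cookies={'name': 'alice', 'season': '...'},  # placeholder values
)
print(resp.text)  # running image count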
Result (around 10 minutes):
Steps Tried:
- Ensured the model is loaded only once.
- Tested in both debug and non-debug modes of Flask.
- Verified the server has adequate resources (a timing probe to narrow this down further is sketched below).
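To check whether the pipeline call itself is slower under waitress, a probe along these lines should work (a sketch, not part of my original code; torch.cuda.synchronize makes the wall-clock numbers reflect actual GPU time):

import time
import torch

def timed_generate(**kwargs):
    torch.cuda.synchronize()   # flush pending GPU work first
    start = time.perf_counter()
    images = pipeline(**kwargs).images
    torch.cuda.synchronize()   # wait for the GPU to finish
    print(f'pipeline call took {time.perf_counter() - start:.1f}s')
    return images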
Questions:
- What could be causing the slow performance and lower image quality in the Flask server?
- Are there any best practices for deploying PyTorch models with Flask to ensure optimal performance?
- Could the choice of WSGI server or server configuration be impacting the performance?
Any guidance or suggestions would be greatly appreciated!
Edit:
I tested the same thing on FastAPI and the behaviour was identical, so it is not a Flask-specific problem; I'll remove the flask tag.
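For reference, the FastAPI test was along these lines (a simplified sketch of the generation path only; the pipeline setup is identical to the Flask version above):

from fastapi import FastAPI
import uvicorn

app = FastAPI()

@app.post('/texttoimage/generate')
def ttig(prompt: str, quality: int = 20, height: int = 800, width: int = 1000):
    # Same pipeline object as in the snippets above
    images = pipeline(prompt, num_inference_steps=quality,
                      height=height // 8 * 8, width=width // 8 * 8).images
    images[0].save('1.png')
    return {'saved': 1}

if __name__ == '__main__':
    uvicorn.run(app, port=80)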