Question

PyTorch Model Performs Poorly in Flask Server Compared to Local Execution

I have a PyTorch model that generates images with high quality and speed when I run it locally. However, when I deploy the same code in a Flask server, the generated images are of much lower quality and the process is extremely slow.

Details:

Local Environment:

OS: Windows 10

Python Version: 3.12

PyTorch Version: 2.3.1+cu121

GPU: [screenshot: GPU information]

Server Environment:

Deployment: Flask

Server: waitress

Hosting: local

Resources: GPU

Code Snippet (standalone):

from diffusers import StableDiffusionPipeline
from torch import float16
pipeline = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=float16)
pipeline.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
pipeline.to('cuda')
pipeline.enable_attention_slicing()
pipeline('best quality, high quality, photorealistic, an astronaut riding a white horse in space ', num_inference_steps=20, negative_prompt='bad quality, low quality', num_images_per_prompt=1, height=800, width=1000).images[0].save('1.png')

Result (less than 1 minute):

[image: perfect result]

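Code Snippet (Flask):

from flask import Flask, request
from diffusers import StableDiffusionPipeline
from torch import float16
app = Flask(__name__)
count = 0  # image counter; its real initial value is not shown in the question
# `info` (the user store read in ttig) is defined elsewhere and not shown here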
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=float16) # stabilityai/sdxl-turbo
pipeline.to("cuda")
pipeline.enable_attention_slicing()
@app.route('/texttoimage/generate', methods=['POST'])
def ttig():
    global count
    if eval(request.cookies['season']) in (user := info[request.cookies['name'].encode()])[1]:
        user = user[4]
        pipeline.safety_checker = lambda images, **kwargs: (images, [False if request.args['NSFW'] == 'true' else True] * len(images))
        images = pipeline(
            request.args['prompt'],
            negative_prompt=request.args.get('negative'),
            num_inference_steps=int(request.args['quality']),
            num_images_per_prompt=int(request.args['count']),
            # round height/width down to a multiple of 8
            height=int(request.args['height']) // 8 * 8,
            width=int(request.args['width']) // 8 * 8,
        ).images
        for k,j in enumerate(images):
            user[f"{count + k}.{request.args['type']}"] = 'a'
            j.save(f"s/{count + k}.{request.args['type']}")
        count += len(images)
        return str(count)
if __name__ == '__main__':
    from waitress import serve
    serve(app, port=80)

Result (around 10 minutes):

[image: terrible result] https://i.sstatic.net/M6ghfzpB.jpg

Steps Tried:

  • Ensured the model is loaded only once.
  • Tested in both debug and non-debug modes of Flask.
  • Verified the server has adequate resources.

Questions:

  1. What could be causing the slow performance and lower image quality in the Flask server?
  2. Are there any best practices for deploying PyTorch models with Flask to ensure optimal performance?
  3. Could the choice of WSGI server or server configuration be impacting the performance?

Any guidance or suggestions would be greatly appreciated!

Edit:

I tested the same thing with FastAPI and got the same result, so it's not a Flask problem; I'll remove the flask tag.


Solution


There are a couple of things that might lead to such different outputs.

  1. Check that the GPU build of torch is installed on the server and that your model can actually use it. Add import torch and print(torch.cuda.is_available()) right after the imports (the snippet only imports float16 from torch); it should print True. While the code is running, monitor GPU usage by running nvidia-smi repeatedly in another terminal (or nvidia-smi -l 1 to refresh every second). GPU usage should rise while the script runs on the GPU.

  2. Compare the environments locally and on the server. Run pip freeze > env_libs.txt on both the server and the local machine; this saves the installed packages to env_libs.txt on each. Then compare the two files. Ideally, every package should have the same version in both.

  3. For debugging, try running the Flask app with the built-in app.run() server instead of Waitress. If it works with plain Flask, you can try hosting the app with Gunicorn instead of Waitress (note that Gunicorn does not run on Windows). The WSGI server's worker and thread settings determine how requests reach the model and the GPU, so you may need to configure them to let the Flask app use the GPU properly.

  4. Does the standalone script give similar output when run on the server without Flask? It should. If it doesn't, solve that first before integrating the model with Flask.

  5. Try this step if the standalone script gives the intended output on the server (step 4). Check that the intended inputs actually reach the model in the Flask app: you can save the request args to a JSON/CSV/TXT file before passing them to the pipeline. The output difference might also be due to different color spaces, or you might be missing a post-processing step required before saving the model's output image. Check whether the output changes when you feed data to the model in batches. Print the shapes of both outputs, on the server and locally, to debug; inspecting both output tensors can give clues about any post-processing required.
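For step 2, comparing the two pip freeze files by eye is error-prone. The sketch below is one way to diff them with only the standard library. The script name and file names are hypothetical (copy the freeze output from each machine next to the script), and the parser only handles the common name==version and name @ url line shapes. If the files were produced with > in Windows PowerShell, they may be UTF-16 encoded, in which case change the encoding argument.

```python
import sys
from pathlib import Path


def parse_freeze(text):
    """Map lowercase package name -> version string from `pip freeze` output."""
    packages = {}
    for raw in text.splitlines():
        line = raw.strip()
        # Skip blank lines, comments and editable installs ("-e git+...")
        if not line or line.startswith(("#", "-e")):
            continue
        if "==" in line:
            name, version = line.split("==", 1)
        elif " @ " in line:
            # Direct references such as "torch @ file:///..."
            name, version = line.split(" @ ", 1)
        else:
            name, version = line, ""
        packages[name.strip().lower()] = version.strip()
    return packages


def diff_freeze(local_text, server_text):
    """Return {package: (local_version, server_version)} for every mismatch.

    A version of None means the package is missing on that machine.
    """
    local, server = parse_freeze(local_text), parse_freeze(server_text)
    diffs = {}
    for name in sorted(local.keys() | server.keys()):
        if local.get(name) != server.get(name):
            diffs[name] = (local.get(name), server.get(name))
    return diffs


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python compare_envs.py env_libs_local.txt env_libs_server.txt
    local_txt = Path(sys.argv[1]).read_text(encoding="utf-8")
    server_txt = Path(sys.argv[2]).read_text(encoding="utf-8")
    for name, (ours, theirs) in diff_freeze(local_txt, server_txt).items():
        print(f"{name}: local={ours} server={theirs}")
```

Pay particular attention to torch (including its +cuXXX suffix), diffusers and transformers in the output.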
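For step 5, one way to see what actually reaches the model is to build the pipeline arguments in a separate function and dump them before the call. This sketch mirrors the parsing in the question's ttig() handler; the function names and the last_request.json path are hypothetical. Comparing the dump with the standalone call makes differences visible: for example, if the client does not send negative, the server runs with negative_prompt=None, unlike the standalone script.

```python
import json


def build_pipeline_kwargs(args):
    """Turn the request's query parameters into pipeline keyword arguments.

    Mirrors the parsing in the question's ttig() handler so the values can be
    compared with the standalone call (20 steps, 800x1000, one image).
    """
    return {
        "prompt": args["prompt"],
        "negative_prompt": args.get("negative"),  # None if the client omits it
        "num_inference_steps": int(args["quality"]),
        "num_images_per_prompt": int(args["count"]),
        # Round down to a multiple of 8, as the original handler does
        "height": int(args["height"]) // 8 * 8,
        "width": int(args["width"]) // 8 * 8,
    }


def dump_kwargs(kwargs, path="last_request.json"):
    """Write the kwargs to JSON so server and standalone runs can be diffed."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(kwargs, f, indent=2, ensure_ascii=False)
    return path
```

Inside the handler you would then call kwargs = build_pipeline_kwargs(request.args), dump_kwargs(kwargs) and images = pipeline(**kwargs).images, and compare the JSON file with the arguments in the standalone script.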

You can also run the server script locally on localhost and check its output. This tells you whether the code itself is correct or whether the problem is environment-related.
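To test the server script on localhost with exactly the same settings as the standalone script, you can send the request from a small client. This is a sketch using only urllib; the parameter names come from the question's ttig() route, while the cookie values are hypothetical placeholders, since the route checks them against the asker's own info store. Setting NSFW to "true" makes the route's safety-checker override report False for every image, as the standalone script's override does.

```python
from urllib.parse import urlencode
from urllib.request import Request

# Same settings as the standalone script in the question
STANDALONE_PARAMS = {
    "prompt": "best quality, high quality, photorealistic, an astronaut riding a white horse in space ",
    "negative": "bad quality, low quality",
    "quality": 20,  # -> num_inference_steps
    "count": 1,     # -> num_images_per_prompt
    "height": 800,
    "width": 1000,
    "type": "png",
    "NSFW": "true",
}


def build_request(base_url, params, cookies):
    """Build the POST request the ttig() route expects (query args + cookies)."""
    url = f"{base_url}/texttoimage/generate?{urlencode(params)}"
    cookie_header = "; ".join(f"{k}={v}" for k, v in cookies.items())
    return Request(url, data=b"", method="POST", headers={"Cookie": cookie_header})
```

For example, build the request with build_request("http://127.0.0.1:80", STANDALONE_PARAMS, {"name": "...", "season": "..."}) and send it with urllib.request.urlopen(req, timeout=900); the route answers with the updated image counter. Then compare the saved image with the standalone result.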

2024-07-19
YadneshD

Solution


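Requests that are handled one at a time can limit GPU utilization when several requests arrive concurrently. Flask's built-in development server has been multithreaded by default since Flask 1.0, and Waitress serves requests from a thread pool (4 threads by default), so check how your server is configured. To handle multiple requests concurrently, consider asynchronous processing with tools like asyncio or ThreadPoolExecutor.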
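A minimal sketch of the ThreadPoolExecutor idea, with a stand-in generate() function in place of the real pipeline call (all names here are hypothetical). It deliberately uses a single worker so that GPU jobs are queued and run one at a time while the server threads only wait for results: sharing one pipeline object between simultaneous calls may not be safe, and the question's handler also reassigns pipeline.safety_checker on every request, which concurrent requests would race on.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# A single worker: GPU jobs run one at a time in submission order, while the
# web server's own threads just wait for the results.
gpu_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="gpu")


def generate(prompt, steps):
    """Stand-in for the blocking call, e.g. pipeline(prompt, ...).images."""
    return f"{prompt} ({steps} steps) on {threading.current_thread().name}"


def handle_request(prompt, steps, timeout=600):
    """Queue one generation job and block until it finishes (or times out)."""
    future = gpu_executor.submit(generate, prompt, steps)
    return future.result(timeout=timeout)
```

In the route you would replace the direct pipeline(...) call with a call through the executor. Adding more workers is only likely to help if each one has its own model copy and the GPU has enough memory for all of them.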

While your code is running, you can monitor GPU memory utilization with nvidia-smi, which shows whether the GPU is actually being used.

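The following code saves the generated images concurrently with asyncio. Note that this only overlaps the file writes; it does not speed up the image generation itself, which is typically the slow part.

import asyncio

async def save_images(images, count, ext, user):
    tasks = []
    for k, j in enumerate(images):
        user[f"{count + k}.{ext}"] = 'a'
        # PIL's Image.save() is blocking, so run each save in a worker thread
        tasks.append(asyncio.to_thread(j.save, f"s/{count + k}.{ext}"))
    await asyncio.gather(*tasks)

asyncio.run(save_images(images, count, request.args['type'], user))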

A more suitable tool for serving models through a web interface is Gradio, a Python library for building web apps around machine-learning models.

2024-07-20
Hui Xiao