Chat With Your SQL Database
Last Updated: September 24, 2024
In this example, we are querying a SQL Database!
Install dependencies
For this demo, we’re using SQLite.
The first few code cells in this section fetch a CSV file on ‘Absenteeism at work’ and create a SQL table from it.
!pip install git+https://github.com/deepset-ai/haystack.git@main#egg=haystack-ai
from urllib.request import urlretrieve
from zipfile import ZipFile
import pandas as pd
url = "https://archive.ics.uci.edu/static/public/445/absenteeism+at+work.zip"
# download the file
urlretrieve(url, "Absenteeism_at_work_AAA.zip")
print("Extracting the Absenteeism at work dataset...")
# Extract the CSV file
with ZipFile("Absenteeism_at_work_AAA.zip", 'r') as zf:
    zf.extractall()
# Check the extracted CSV file name (in this case, it's "Absenteeism_at_work.csv")
csv_file_name = "Absenteeism_at_work.csv"
print("Cleaning up the Absenteeism at work dataset...")
# Data clean up
df = pd.read_csv(csv_file_name, sep=";")
df.columns = df.columns.str.replace(' ', '_')
df.columns = df.columns.str.replace('/', '_')
Extracting the Absenteeism at work dataset...
Cleaning up the Absenteeism at work dataset...
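The two str.replace calls in the cleanup cell turn the raw CSV headers into identifiers that SQLite accepts without quoting. The following stdlib-only sketch applies the same two replacements. The raw header names in it are assumptions, reconstructed from the cleaned column names printed below.

```python
def normalize_column(name: str) -> str:
    """Replace spaces and slashes with underscores, like the two
    df.columns.str.replace calls in the cleanup cell."""
    return name.replace(" ", "_").replace("/", "_")


# Hypothetical raw headers, inferred from the cleaned names printed below.
raw_headers = ["Reason for absence", "Work load Average/day ", "Body mass index"]
cleaned = [normalize_column(h) for h in raw_headers]
print(cleaned)
# ['Reason_for_absence', 'Work_load_Average_day_', 'Body_mass_index']
```

The trailing space in "Work load Average/day " explains why the cleaned column ends in an underscore.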
columns = df.columns.to_list()
columns = ', '.join(columns)
columns
'ID, Reason_for_absence, Month_of_absence, Day_of_the_week, Seasons, Transportation_expense, Distance_from_Residence_to_Work, Service_time, Age, Work_load_Average_day_, Hit_target, Disciplinary_failure, Education, Son, Social_drinker, Social_smoker, Pet, Weight, Height, Body_mass_index, Absenteeism_time_in_hours'
import sqlite3
connection = sqlite3.connect('absenteeism.db')
print("Opened database successfully")
connection.execute('''CREATE TABLE IF NOT EXISTS absenteeism (ID integer,
Reason_for_absence integer,
Month_of_absence integer,
Day_of_the_week integer,
Seasons integer,
Transportation_expense integer,
Distance_from_Residence_to_Work integer,
Service_time integer,
Age integer,
Work_load_Average_day_ integer,
Hit_target integer,
Disciplinary_failure integer,
Education integer,
Son integer,
Social_drinker integer,
Social_smoker integer,
Pet integer,
Weight integer,
Height integer,
Body_mass_index integer,
Absenteeism_time_in_hours integer);''')
connection.commit()
Opened database successfully
df.to_sql('absenteeism', connection, if_exists='replace', index = False)
740
connection.close()
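Note that df.to_sql(..., if_exists='replace') drops the table before inserting. The schema from the CREATE TABLE statement is therefore replaced by one that pandas infers from the DataFrame's dtypes.

If you would rather not depend on pandas for loading, here is a different technique, sketched with the stdlib only. It reads the semicolon-separated CSV with csv.reader and inserts the rows with executemany. The two-row sample is made up, and every column is assumed to be an integer, which does not hold for the full dataset.

```python
import csv
import io
import sqlite3

# Made-up two-row sample in the same ';'-separated layout as the UCI CSV.
sample_csv = "ID;Age;Absenteeism time in hours\n11;33;4\n36;50;0\n"

reader = csv.reader(io.StringIO(sample_csv), delimiter=";")
# Normalize headers the same way as the pandas cleanup above.
header = [h.replace(" ", "_").replace("/", "_") for h in next(reader)]
rows = [tuple(int(v) for v in row) for row in reader]

conn = sqlite3.connect(":memory:")  # in-memory DB, so no file is written
# Identifiers cannot be bound as ? parameters, so they are formatted in;
# values are still passed as parameters.
cols = ", ".join(f"{h} INTEGER" for h in header)
conn.execute(f"CREATE TABLE absenteeism ({cols})")
placeholders = ", ".join("?" for _ in header)
conn.executemany(f"INSERT INTO absenteeism VALUES ({placeholders})", rows)
conn.commit()

count = conn.execute("SELECT COUNT(*) FROM absenteeism").fetchone()[0]
print(count)  # 2
```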
Create a SQL Query Component
Here, we’re creating a custom component called SQLQuery, so that we can use it in our Haystack pipeline like any other component (a retriever, a generator, etc.). This component does just one thing:
- Accepts queries, which are SQL queries
- Queries the database with those SQL queries and returns the results from the database.
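Without the @component decorator, the run method is just a loop. The stdlib-only sketch below keeps the same contract: a list of queries goes in, and results plus queries come out. It is a hypothetical stand-in, not the component itself. It assumes a plain sqlite3 cursor, with rows formatted via str() instead of a rendered pandas DataFrame.

```python
import sqlite3
from typing import Dict, List


class PlainSQLQuery:
    """Hypothetical Haystack-free stand-in for the SQLQuery component."""

    def __init__(self, connection: sqlite3.Connection):
        self.connection = connection

    def run(self, queries: List[str]) -> Dict[str, List[str]]:
        results = []
        for query in queries:
            cursor = self.connection.execute(query)
            header = [col[0] for col in cursor.description]  # column names
            rows = cursor.fetchall()
            # One string per query, mirroring f"{result}" in the component.
            results.append("\n".join([str(header)] + [str(row) for row in rows]))
        return {"results": results, "queries": queries}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE absenteeism (Age INTEGER, Absenteeism_time_in_hours INTEGER)")
conn.executemany("INSERT INTO absenteeism VALUES (?, ?)", [(28, 8), (33, 4), (28, 2)])

qs = ["SELECT COUNT(*) AS n FROM absenteeism", "SELECT MAX(Age) AS oldest FROM absenteeism"]
out = PlainSQLQuery(conn).run(qs)
print(out["results"])  # ["['n']\n(3,)", "['oldest']\n(33,)"]
```

The component also returns the queries themselves, so downstream steps can show which SQL produced which result.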
import sqlite3
from typing import List
import pandas as pd
from haystack import component
@component
class SQLQuery:
    def __init__(self, sql_database: str):
        # check_same_thread=False lets other threads (e.g. the Gradio app below) reuse this connection
        self.connection = sqlite3.connect(sql_database, check_same_thread=False)
    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[str]):
        results = []
        for query in queries:
            # Run each query and keep the resulting DataFrame rendered as a string
            result = pd.read_sql(query, self.connection)
            results.append(f"{result}")
        return {"results": results, "queries": queries}
Try the SQLQuery Component
sql_query = SQLQuery('absenteeism.db')
result = sql_query.run(queries=['SELECT Age, SUM(Absenteeism_time_in_hours) as Total_Absenteeism_Hours FROM absenteeism WHERE Disciplinary_failure = 0 GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3;'])
print(result["results"][0])
Age Total_Absenteeism_Hours
0 28 651
1 33 538
2 38 482
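The WHERE clause in this query filters rows before GROUP BY aggregates them. ORDER BY ... DESC LIMIT 3 then keeps the three largest totals. The stdlib sketch below runs the same query on a handful of made-up rows. In this toy data, one age would top the ranking if its disciplinary-failure row were counted.

```python
import sqlite3

# Made-up rows: (Age, Disciplinary_failure, Absenteeism_time_in_hours)
rows = [(28, 0, 8), (28, 0, 8), (33, 0, 10), (38, 0, 3), (38, 1, 40), (41, 0, 1)]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE absenteeism (Age INTEGER, Disciplinary_failure INTEGER, "
    "Absenteeism_time_in_hours INTEGER)"
)
conn.executemany("INSERT INTO absenteeism VALUES (?, ?, ?)", rows)

query = """SELECT Age, SUM(Absenteeism_time_in_hours) AS Total_Absenteeism_Hours
FROM absenteeism WHERE Disciplinary_failure = 0
GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3"""
top = conn.execute(query).fetchall()
print(top)  # [(28, 16), (33, 10), (38, 3)]
```

Without the WHERE clause, age 38 would total 43 hours and move to the top.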
Query A SQL Database with Natural Language
In this section, we’re building a simple pipeline that can:
- Accept natural language questions
- Translate those questions into SQL queries
- Query our database using the SQLQuery component
Shortcoming: This pipeline will still run if you ask a completely unrelated question that cannot be answered with the database we have at hand. Observe how the SQLQuery component raises an error in these cases.
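One way to soften this shortcoming, which is not part of the original notebook, is to catch database errors per query and return them as text. An off-topic question then yields a readable message instead of an exception. The sketch below assumes a plain sqlite3 connection rather than pd.read_sql, and run_queries_safely is a hypothetical helper.

```python
import sqlite3
from typing import List


def run_queries_safely(connection: sqlite3.Connection, queries: List[str]) -> List[str]:
    """Hypothetical helper: execute each query, returning rows or an error string."""
    results = []
    for query in queries:
        try:
            rows = connection.execute(query).fetchall()
            results.append(str(rows))
        except sqlite3.Error as exc:  # syntax errors, unknown columns, etc.
            results.append(f"Query failed: {exc}")
    return results


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE absenteeism (Age INTEGER)")
# The second "query" is the kind of non-SQL text an LLM may return.
safe_results = run_queries_safely(conn, ["SELECT Age FROM absenteeism", "I cannot answer that."])
print(safe_results)
```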
import os
from getpass import getpass
os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")
OpenAI API Key: ··········
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
The query is to be answered for the table called 'absenteeism' with the following
Columns: {{columns}};
Answer:""")
sql_query = SQLQuery('absenteeism.db')
llm = OpenAIGenerator(model="gpt-4")
sql_pipeline = Pipeline()
sql_pipeline.add_component("prompt", prompt)
sql_pipeline.add_component("llm", llm)
sql_pipeline.add_component("sql_querier", sql_query)
sql_pipeline.connect("prompt", "llm")
sql_pipeline.connect("llm.replies", "sql_querier.queries")
# If you want to draw the pipeline, uncomment the line below
# sql_pipeline.show()
result = sql_pipeline.run({"prompt": {"question": "On which days of the week does the average absenteeism time exceed 4 hours?",
"columns": columns}})
print(result["sql_querier"]["results"][0])
Day_of_the_week
0 2
1 3
2 4
3 5
4 6
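Conceptually, this pipeline chains three calls. In the Haystack-free sketch below, a hypothetical stub replaces the OpenAI generator and returns a fixed query. This shows how replies (a list of strings) flows into queries without an API key. The simplified prompt, the stub, and the toy rows are all assumptions.

```python
import sqlite3
from typing import Dict, List


def build_prompt(question: str, columns: str) -> str:
    # Simplified stand-in for the PromptBuilder template.
    return f"Please generate an SQL query. Question: {question}; Columns: {columns}; Answer:"


def stub_llm(prompt: str) -> Dict[str, List[str]]:
    # Hypothetical stub: a real generator would send `prompt` to a model.
    return {"replies": ["SELECT Day_of_the_week FROM absenteeism GROUP BY Day_of_the_week "
                        "HAVING AVG(Absenteeism_time_in_hours) > 4 ORDER BY Day_of_the_week"]}


def run_queries(connection: sqlite3.Connection, queries: List[str]) -> Dict[str, List[str]]:
    return {"results": [str(connection.execute(q).fetchall()) for q in queries], "queries": queries}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE absenteeism (Day_of_the_week INTEGER, Absenteeism_time_in_hours INTEGER)")
conn.executemany("INSERT INTO absenteeism VALUES (?, ?)", [(2, 9), (2, 7), (3, 2), (4, 6)])

prompt_text = build_prompt("On which days does average absenteeism exceed 4 hours?",
                           "Day_of_the_week, Absenteeism_time_in_hours")
replies = stub_llm(prompt_text)["replies"]    # plays the role of llm.replies
output = run_queries(conn, queries=replies)   # ...connected to sql_querier.queries
print(output["results"][0])  # [(2,), (4,)]
```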
Skip for Unrelated Questions: Add a Condition
Now, let’s create another pipeline that avoids querying the database when the question is unrelated to the information in it. For this, we do a few things:
- We modify the prompt to answer with no_answer if the question cannot be answered given the database and its columns
- We add a conditional router that routes the query to the SQLQuery component only if the question was evaluated to be answerable
- We add a fallback_prompt and fallback_llm that return a statement explaining why the question cannot be answered. This branch of the pipeline runs only if the question cannot be answered.
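The routing rule described above reduces to one substring test on the first LLM reply. Here it is in plain Python. This sketch mirrors, but does not reuse, Haystack's ConditionalRouter, and route is a hypothetical helper.

```python
from typing import Dict, List, Union


def route(replies: List[str], question: str) -> Dict[str, Union[List[str], str]]:
    """Send SQL onward, or send the question to the fallback branch."""
    if "no_answer" not in replies[0]:
        return {"sql": replies}
    return {"go_to_fallback": question}


print(route(["SELECT COUNT(*) FROM absenteeism"], "How many rows?"))
# {'sql': ['SELECT COUNT(*) FROM absenteeism']}
print(route(["no_answer"], "When is my birthday?"))
# {'go_to_fallback': 'When is my birthday?'}
```

Because this is a substring check, any reply that merely mentions no_answer goes to the fallback. The prompt therefore asks the model to return exactly that token.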
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.routers import ConditionalRouter
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
If the question cannot be answered given the provided table and columns, return 'no_answer'
The query is to be answered for the table called 'absenteeism' with the following
Columns: {{columns}};
Answer:""")
llm = OpenAIGenerator(model="gpt-4")
sql_query = SQLQuery('absenteeism.db')
routes = [
{
"condition": "{{'no_answer' not in replies[0]}}",
"output": "{{replies}}",
"output_name": "sql",
"output_type": List[str],
},
{
"condition": "{{'no_answer' in replies[0]}}",
"output": "{{question}}",
"output_name": "go_to_fallback",
"output_type": str,
},
]
router = ConditionalRouter(routes)
fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
The query was: {{question}} and the table had columns: {{columns}}.
Let the user know why the question cannot be answered""")
fallback_llm = OpenAIGenerator(model="gpt-4")
conditional_sql_pipeline = Pipeline()
conditional_sql_pipeline.add_component("prompt", prompt)
conditional_sql_pipeline.add_component("llm", llm)
conditional_sql_pipeline.add_component("router", router)
conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
conditional_sql_pipeline.add_component("sql_querier", sql_query)
conditional_sql_pipeline.connect("prompt", "llm")
conditional_sql_pipeline.connect("llm.replies", "router.replies")
conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")
# If you want to draw the pipeline, uncomment the line below
#conditional_sql_pipeline.show()
question = "When is my birthday?"
result = conditional_sql_pipeline.run({"prompt": {"question": question,
"columns": columns},
"router": {"question": question},
"fallback_prompt": {"columns": columns}})
if 'sql_querier' in result:
    print(result['sql_querier']['results'][0])
elif 'fallback_llm' in result:
    print(result['fallback_llm']['replies'][0])
The query cannot be answered as the provided table does not contain information regarding the user's personal data such as birthdays. The table primarily focuses on absence-related data for presumably work or similar situations. Please provide the relevant data to get the accurate answer.
Function Calling to Query a SQL Database
Now let’s try something a bit more fun. Instead of a component, we are going to provide SQL querying as a function. Since we already built the component, we can simply wrap our SQLQuery component in a function.
sql_query = SQLQuery('absenteeism.db')
def sql_query_func(queries: List[str]):
    try:
        result = sql_query.run(queries)
        # Only the result of the first query is returned to the LLM
        return {"reply": result["results"][0]}
    except Exception as e:
        # Return the error as text so the LLM can correct its query and retry
        reply = f"""There was an error running the SQL Query = {queries}
        The error is {e},
        You should probably try again.
        """
        return {"reply": reply}
Define Tools
Now, let’s provide this function as a tool. Below, we are using OpenAI for demonstration purposes, so we follow their function definition schema.
tools = [
    {
        "type": "function",
        "function": {
            "name": "sql_query_func",
            "description": f"This is a tool useful to query a SQL table called 'absenteeism' with the following Columns: {columns}",
            "parameters": {
                "type": "object",
                "properties": {
                    "queries": {
                        "type": "array",
                        "description": "The SQL queries to run against the 'absenteeism' table. Infer these from the user's message.",
                        "items": {
                            "type": "string",
                        }
                    }
                },
                # Must name a key defined under "properties"
                "required": ["queries"],
            },
        },
    }
]
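Every name in required must also appear under properties. Otherwise the model is asked for an argument the schema never defines. The stdlib sketch below catches that kind of mismatch and checks a model's arguments before they reach the database. check_schema and parse_queries are hypothetical helpers, not OpenAI or Haystack APIs.

```python
import json
from typing import Any, Dict, List


def check_schema(tool: Dict[str, Any]) -> List[str]:
    """Return required parameter names that are missing from `properties`."""
    params = tool["function"]["parameters"]
    return [name for name in params.get("required", []) if name not in params["properties"]]


def parse_queries(arguments: str) -> List[str]:
    """Decode a tool call's JSON arguments and insist on a list of strings."""
    args = json.loads(arguments)
    queries = args.get("queries")
    if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
        raise ValueError(f"expected 'queries' to be a list of strings, got {queries!r}")
    return queries


tool = {"type": "function", "function": {"name": "sql_query_func", "parameters": {
    "type": "object",
    "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
    "required": ["queries"]}}}
print(check_schema(tool))                          # []
print(parse_queries('{"queries": ["SELECT 1"]}'))  # ['SELECT 1']
```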
Try The Tool
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
messages = [
ChatMessage.from_system(
"You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'absenteeism'"
),
ChatMessage.from_user("On which days of the week does the average absenteeism time exceed 4 hours?"),
]
chat_generator = OpenAIChatGenerator(model="gpt-4", streaming_callback=print_streaming_chunk)
response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
print(response)
{'replies': [ChatMessage(content='[{"index": 0, "id": "call_fRYwYg6iAqroHwYzPD6UxOVg", "function": {"arguments": "{\\n \\"queries\\": [\\"SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) AS Average_Absenteeism_Hours FROM absenteeism GROUP BY Day_of_the_week HAVING AVG(Absenteeism_time_in_hours) > 4\\"]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {}})]}
import json
## Parse function calling information
function_call = json.loads(response["replies"][0].content)[0]
function_name = function_call["function"]["name"]
function_args = json.loads(function_call["function"]["arguments"])
print("Function Name:", function_name)
print("Function Arguments:", function_args)
## Find the corresponding function and call it with the given arguments
available_functions = {"sql_query_func": sql_query_func}
function_to_call = available_functions[function_name]
function_response = function_to_call(**function_args)
print("Function Response:", function_response)
Function Name: sql_query_func
Function Arguments: {'queries': ['SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) AS Average_Absenteeism_Hours FROM absenteeism GROUP BY Day_of_the_week HAVING AVG(Absenteeism_time_in_hours) > 4']}
Function Response: {'reply': ' Day_of_the_week Average_Absenteeism_Hours\n0 2 9.248447\n1 3 7.980519\n2 4 7.147436\n3 5 4.424000\n4 6 5.125000'}
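The printed response shows that the tool calls arrive as a JSON string in the message content, and that each call's arguments is itself a JSON string. Dispatching is therefore plain json work. The stdlib sketch below generalizes the parse-and-dispatch step to every call in the list, whereas the cell above handles only the first. It uses a hypothetical stub in place of sql_query_func and a shortened, made-up payload.

```python
import json
from typing import Any, Callable, Dict, List

# Same shape as the printed response; id and query are made up.
content = ('[{"index": 0, "id": "call_1", "function": {"arguments": '
           '"{\\"queries\\": [\\"SELECT 1\\"]}", "name": "sql_query_func"}, "type": "function"}]')


def fake_sql_query_func(queries: List[str]) -> Dict[str, str]:
    # Hypothetical stub standing in for sql_query_func.
    return {"reply": f"ran {len(queries)} query"}


def dispatch(payload: str, registry: Dict[str, Callable[..., Any]]) -> List[Any]:
    responses = []
    for call in json.loads(payload):
        name = call["function"]["name"]
        args = json.loads(call["function"]["arguments"])  # second decode: arguments is a string
        responses.append(registry[name](**args))
    return responses


print(dispatch(content, {"sql_query_func": fake_sql_query_func}))
# [{'reply': 'ran 1 query'}]
```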
Build a Chat with SQL App
First, let’s install Gradio, which we will use for our mini app.
!pip install gradio
import gradio as gr
import json
from haystack.dataclasses import ChatMessage
from haystack.components.generators.chat import OpenAIChatGenerator
chat_generator = OpenAIChatGenerator(model="gpt-4")
response = None
messages = [
    ChatMessage.from_system(
        "You are a helpful and knowledgeable agent who has access to an SQL database which has a table called 'absenteeism'"
    )
]
def chatbot_with_fc(message, history):
    available_functions = {"sql_query_func": sql_query_func}
    messages.append(ChatMessage.from_user(message))
    response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
    while True:
        # if OpenAI response is a tool call
        if response and response["replies"][0].meta["finish_reason"] == "tool_calls":
            function_calls = json.loads(response["replies"][0].content)
            for function_call in function_calls:
                ## Parse function calling information
                function_name = function_call["function"]["name"]
                function_args = json.loads(function_call["function"]["arguments"])
                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions[function_name]
                function_response = function_to_call(**function_args)
                ## Append function response to the messages list using `ChatMessage.from_function`
                messages.append(ChatMessage.from_function(content=function_response['reply'], name=function_name))
            # Ask the model again once all tool results are in the conversation
            response = chat_generator.run(messages=messages, generation_kwargs={"tools": tools})
        # Regular Conversation
        else:
            messages.append(response["replies"][0])
            break
    return response["replies"][0].content
demo = gr.ChatInterface(
fn=chatbot_with_fc,
examples=[
"Find the top 3 ages with the highest total absenteeism hours, excluding disciplinary failures",
"On which days of the week does the average absenteeism time exceed 4 hours?",
"Who lives in London?",
],
title="Chat with your SQL Database",
)
demo.launch(debug=True)
Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://83eb0414c1916d8ee7.gradio.live
This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
ChatMessage(content='[{"id": "call_Uu8QXlIsJfYCULD4Q0bEcAtP", "function": {"arguments": "{\\n \\"queries\\": [\\n \\"SELECT Age, SUM(Absenteeism_time_in_hours) as Total_Absenteeism_Hours FROM absenteeism WHERE Disciplinary_failure = 0 GROUP BY Age ORDER BY Total_Absenteeism_Hours DESC LIMIT 3\\"\\n ]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {'completion_tokens': 68, 'prompt_tokens': 207, 'total_tokens': 275}})
Age Total_Absenteeism_Hours
0 28 651
1 33 538
2 38 482
ChatMessage(content='[{"id": "call_t8bjUHMvHHrXReB2qm3iVkNF", "function": {"arguments": "{\\n\\"queries\\": [\\"SELECT Day_of_the_week, AVG(Absenteeism_time_in_hours) as average_absenteeism_time FROM absenteeism GROUP BY Day_of_the_week HAVING average_absenteeism_time > 4\\"]\\n}", "name": "sql_query_func"}, "type": "function"}]', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-0613', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {'completion_tokens': 57, 'prompt_tokens': 320, 'total_tokens': 377}})
Day_of_the_week average_absenteeism_time
0 2 9.248447
1 3 7.980519
2 4 7.147436
3 5 4.424000
4 6 5.125000
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://83eb0414c1916d8ee7.gradio.live
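The while True loop in chatbot_with_fc keeps calling the model until it stops requesting tools. Nothing bounds the number of rounds, so a model that keeps issuing failing queries could loop for a long time. The sketch below keeps the same control flow but adds an explicit cap. The max_rounds limit, the dict-based messages and the scripted fake model are all assumptions for illustration, not Haystack APIs.

```python
from typing import Callable, Dict, List


def tool_loop(model: Callable[[List[dict]], dict],
              tools: Dict[str, Callable[..., str]],
              messages: List[dict],
              max_rounds: int = 5) -> str:
    """Call `model` until it returns a plain answer or `max_rounds` is reached."""
    for _ in range(max_rounds):
        reply = model(messages)
        if reply["finish_reason"] != "tool_calls":
            messages.append({"role": "assistant", "content": reply["content"]})
            return reply["content"]
        # Run every requested tool and append its result before asking again.
        for call in reply["calls"]:
            result = tools[call["name"]](**call["arguments"])
            messages.append({"role": "function", "name": call["name"], "content": result})
    return "Stopped: too many tool calls."


def fake_model(messages: List[dict]) -> dict:
    # Scripted: request a query after a user turn, otherwise answer from the last result.
    if messages[-1]["role"] == "user":
        return {"finish_reason": "tool_calls",
                "calls": [{"name": "sql_query_func", "arguments": {"queries": ["SELECT 1"]}}]}
    return {"finish_reason": "stop", "content": f"The database says {messages[-1]['content']}"}


history = [{"role": "user", "content": "Ping the DB"}]
answer = tool_loop(fake_model, {"sql_query_func": lambda queries: "1"}, history)
print(answer)  # The database says 1
```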