OpenAI with Guardrail

Use Case Overview

We have built a chatbot split into two sides: a user side and an admin side. The user interacts with the chatbot by asking questions, while the admin owns the bot and defines the criteria that questions must satisfy.

Users can ask questions and receive answers just like with the previous chatbot we built, with one difference: questions are only answered if they fall within the specified criteria. For example, if this chatbot is configured to respond only to questions about animals and the user asks about anything else, the chatbot replies that it cannot answer that question.

On the admin side, there is the privilege of defining the criteria for questions: criteria can be added and cleared in the admin section of the chatbot, which also displays all the criteria currently in use.

User Flow

Admin flow diagram

User flow diagram

Frontend Development

Admin Interface for Rule Management

In the admin's rule-management UI, the admin can create and delete rules and view the rules currently in effect. The interface is built with React, which manages the various UI states. For form handling, the libraries react-hook-form, @hookform/resolvers/zod, and zod provide input validation. The connection to the API is handled with @tanstack/react-query, which offers several conveniences for API interaction, such as useQueries for fetching multiple resources at the same time.

const [{ data: session }, { data: rule, refetch: refetchRule }] = useQueries({
  queries: [
    {
      queryKey: ['session', endPoint],
      queryFn: async () => {
        const { data } = await axios.get(`/session`)
        return data
      },
      // Only request a new session when none is cached in localStorage.
      enabled: typeof window !== 'undefined' && !localStorage?.session,
    },
    {
      queryKey: ['rule', endPoint],
      queryFn: async () => {
        const { data } = await axios.get(`/rule`)
        return data
      },
    },
  ],
})

For adding and clearing rules, useMutation is used. Its advantage is that it handles the various request states automatically; for example, after a rule is successfully added or cleared, the onSuccess callback triggers refetchRule to refresh the table.

const { mutateAsync: createRule } = useMutation({
  mutationFn: (rule: string) => {
    return axios.post('/create_rule', { rule })
  },
  // Refresh the rules table once the new rule has been saved.
  onSuccess: () => {
    refetchRule()
  },
})

const { mutateAsync: clearRule } = useMutation({
  mutationFn: () => {
    return axios.get('/clear_collection')
  },
  onSuccess: () => {
    refetchRule()
  },
})

User Interface for Question Submission

In the chatbot interface, there is an input field for typing messages to the AI. When the user types something unrelated to the configured rules, the API responds that it is a "bad prompt" and asks the user to try another question.

Now, let's see what happens when we ask a question related to cats: the AI answers the question about the cat we asked about. The interaction with the chatbot takes place in a streaming format.

useEffect(() => {
    const message = { query: watch('query') }
    const getData = async () => {
      try {
        setValue('query', '')
        // Encode the message so spaces and special characters survive the query string.
        const response = await fetch(
          `${endPoint ?? API_URL}/query?uuid=${localStorage?.session}&message=${encodeURIComponent(message.query)}`,
          {
            method: 'GET',
            headers: {
              Accept: 'text/event-stream',
              'x-api-key': localStorage?.apiKey,
            },
          }
        )

        if (response.status === 200) {
          // Read the event stream chunk by chunk and render partial output as it arrives.
          const reader = response.body!.getReader()
          let result = ''
          while (true) {
            const { done, value } = await reader.read()
            if (done) {
              // Stream finished: move the accumulated text into the answer list.
              setStreamText('')
              setAnswer((prevState) => [
                ...prevState,
                {
                  id: (prevState.length + 1).toString(),
                  role: 'ai',
                  message: result,
                },
              ])
              break
            }
            result += new TextDecoder().decode(value)
            setStreamText(result)
          }
        } else {
          if (response.status === 404) {
            // The session expired on the server: clear it and fetch a new one.
            localStorage.removeItem('session')
            refreshSession()
            setError('bot', {
              message: 'Something went wrong, please try again',
            })
            return
          }

          setError('bot', {
            message: 'Something went wrong',
          })
        }
      } catch (error: any) {
        console.error(error)
        setError('bot', {
          message: error?.response?.data?.message ?? 'Something went wrong',
        })
      }
    }
    if (isSubmitSuccessful && localStorage) {
      getData()
    }
  }, [
    submitCount,
    isSubmitSuccessful,
    setValue,
    watch,
    setError,
    endPoint,
    refreshSession,
  ])

Guardrail Interaction

The Guardrail feature limits which prompts the chatbot will answer, based on the rules set by the admin. In a Retrieval-Augmented Generation (RAG) style check, the backend measures how closely the prompt matches the stored rules before sending it to OpenAI. If the average similarity score falls below the threshold (0.7 in this implementation; see compare_similarity in the backend section), the API responds with a "bad prompt" message, indicating that the question is outside the defined rules.

Error Handling and User Feedback

Error handling is implemented for server connection issues and user input errors. For example, on a 404 (not found) status code, the stored session is cleared and a new session is fetched. Validation errors for user input are handled using react-hook-form, @hookform/resolvers/zod, and zod.

import { z } from 'zod'

export const askScheme = z.object({
  apiKey: z.string().optional(),
  query: z.string().trim().min(1, { message: 'Please enter your message' }),
  bot: z.string().optional(),
})

export const ruleScheme = z.object({
  rule: z.string().trim().min(1, { message: 'Please enter your rule' }),
})

export const endpointScheme = z.object({
  endpoint: z.string().url().optional(),
})

const methodRule = useForm<IRuleForm>({
  resolver: zodResolver(ruleScheme),
  mode: 'onChange',
  shouldFocusError: true,
})

const methodsEndpoint = useForm<IEndPointForm>({
  resolver: zodResolver(endpointScheme),
  mode: 'onChange',
  shouldFocusError: true,
})

const methods = useForm<IOpenAIForm>({
  resolver: zodResolver(askScheme),
  mode: 'onChange',
  shouldFocusError: true,
  defaultValues: {
    apiKey: '',
    query: '',
  },
})

Link to Frontend Code

Backend Development

Setup FastAPI

Install the necessary Python 3 libraries for FastAPI:

pip install fastapi
pip install "uvicorn[standard]"

Create a file named guardrail.py and initialize FastAPI:

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/helloworld")
async def helloworld():
    return {"message": "Hello World"}

Run the FastAPI server using the command:

uvicorn guardrail:app

This command runs the FastAPI application declared as app in the guardrail.py file; the server defaults to port 8000.
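
To verify that the server is up, you can call the hello-world endpoint; here is a quick sanity check using the requests library (assuming the default localhost:8000):

import requests

# Expect {'message': 'Hello World'} from the endpoint defined above.
print(requests.get("http://localhost:8000/helloworld").json())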

The API is divided into two parts: one for users, who send prompts to the backend to interact with the AI, and another for administrators, who control which prompts may be sent to the AI.

Install the other necessary dependencies (uuid ships with the Python standard library, so it does not need a separate install; chromadb is required by the Chroma vector store used below):

pip install openai
pip install pydantic
pip install langchain
pip install chromadb

The admin API includes endpoints for creating, viewing, and deleting the rules used to filter prompts; creating a rule essentially adds a document to the VectorDB. The VectorDB initialization is covered later.

from pydantic import BaseModel

# Request body for /create_rule.
class Rule(BaseModel):
    rule: str

def add_rule(text):
    # Embed the rule text and persist it in the Chroma collection.
    Chroma.from_texts(
        collection_name=collection_name,
        texts=[text],
        embedding=embedding,
        persist_directory=persist_dir,
    )

@app.post('/create_rule')
def create_rule(rule: Rule):
    add_rule(rule.rule)
    return JSONResponse(content={
        "message": "success",
        "rule": rule.rule,
    })

View all created rules:

def get_collection():
    vectorstore = init_db()
    result = vectorstore.get()
    return result['documents']

@app.get('/rule')
def get_rule():
    res = get_collection()
    return JSONResponse(content={"result": res})

Clear all rules:

def clear_db_collection():
    vectorstore = init_db()
    # Drop the entire collection, removing every stored rule.
    vectorstore.delete_collection()

@app.get('/clear_collection')
def clear_collection():
    clear_db_collection()
    return JSONResponse(content={"message": "remove complete"})
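
As a quick illustration of the admin flow, the three endpoints can be exercised with the requests library; this is a hypothetical smoke test, assuming the server runs on localhost:8000:

import requests

base = "http://localhost:8000"

# Add a rule, list the stored rules, then clear the collection.
requests.post(f"{base}/create_rule", json={"rule": "Only answer questions about animals"})
print(requests.get(f"{base}/rule").json())  # {"result": ["Only answer questions about animals"]}
requests.get(f"{base}/clear_collection")    # removes every stored rule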

For user interaction, an API is provided to send prompts and check whether they match any rules:

@app.get("/session")
def session():
    client_uuid = uuid.uuid4()
    create_session(str(client_uuid))
    result = {
        "uuid" : str(client_uuid)
    }
    return JSONResponse(content=result)

@app.get("/query")
async def main(uuid:str,message:str):
    get_session(uuid)
    collection = get_collection()
    if  len(collection) == 0:
        return StreamingResponse(
                    generate_response('No rules configure please ask admin'), 
                    media_type="text/event-stream"
                )
    else:
        score = compare_similarity(message)
        print(score)
        if score > 0.7:
            return StreamingResponse(
                        stream_chat(
                            uuid = uuid,
                            prompt= message
                        ), 
                        media_type="text/event-stream"
                    )
        else:
            return StreamingResponse(
                        generate_response('bad prompt please ask another one'), 
                        media_type="text/event-stream"
                    )
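
The helpers create_session, get_session, add_message, and generate_response are referenced above but not shown in this section. A minimal sketch, assuming a simple in-memory dict as the session store (the real implementation may differ):

from fastapi import HTTPException

# Hypothetical in-memory session store: uuid -> chat history.
sessions = {}

def create_session(uuid: str):
    # Seed the history so the chat API has a system message to start from.
    sessions[uuid] = [{"role": "system", "content": "You are a helpful assistant."}]

def get_session(uuid: str):
    # The frontend treats a 404 as an expired session and requests a new one.
    if uuid not in sessions:
        raise HTTPException(status_code=404, detail="session not found")
    return sessions[uuid]

def add_message(uuid: str, role: str, content: str):
    # Append to the history and return it for the chat completion call.
    sessions[uuid].append({"role": role, "content": content})
    return sessions[uuid]

def generate_response(text: str):
    # Wrap a fixed message in a generator so it can be streamed like a chat reply.
    yield text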

Integrating OpenAI API

Connect to the OpenAI API using the API key generated from the OpenAI console. Two methods are provided: one using an Embedding model and the other using the Chatbot API.

  1. Embedding model:

from langchain.embeddings import OpenAIEmbeddings
api_key = 'sk-XcTnjgYVsJQMNxxxxxxxxxxxxxx'
embedding = OpenAIEmbeddings(openai_api_key=api_key)
  2. Chatbot API:

import openai

openai.api_key = api_key

def stream_chat(uuid: str, prompt: str):
    # Uses the pre-1.0 openai SDK interface (openai.ChatCompletion).
    result = ""
    messages = add_message(uuid, 'user', prompt)
    for chunk in openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        stream=True,
    ):
        # Each streamed chunk carries an incremental piece of the reply.
        content = chunk["choices"][0].get("delta", {}).get("content")
        if content is not None:
            result = result + content
            yield content
    # Store the full reply in the session history.
    add_message(uuid, 'assistant', result)
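
As a hypothetical usage sketch (not part of the original code), both integrations can be exercised directly, assuming a valid OpenAI API key is set and using the in-memory session helpers sketched earlier:

# Embed a query: returns a vector, e.g. 1536 floats for text-embedding-ada-002.
vector = embedding.embed_query("What do cats eat?")
print(len(vector))

# Stream a chat reply token by token.
create_session("demo-uuid")
for token in stream_chat(uuid="demo-uuid", prompt="Tell me about cats"):
    print(token, end="", flush=True)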

Implementing VectorDB with Chroma

Chroma is used as the VectorDB; it stores vectorized data so that similar entries can be found efficiently. It is initialized with a collection name, the on-disk data location, and the embedding model.

from langchain.vectorstores import Chroma

def init_db():
    vectorstore = Chroma(
        collection_name=collection_name,
        persist_directory=persist_dir,
        embedding_function=embedding,
    )
    return vectorstore
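
Note that init_db (and the earlier add_rule) rely on collection_name and persist_dir, which are assumed to be module-level constants; the values below are illustrative:

# Illustrative module-level configuration; choose values to suit your project.
collection_name = "rules"    # Chroma collection that stores the admin rules
persist_dir = "./chroma_db"  # directory where Chroma persists vectors on disk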

Building the Guardrail Feature

The Guardrail feature queries the rule data in VectorDB with the received prompt; if the average similarity score of the closest matches exceeds the threshold, the prompt is forwarded to the chatbot.

def compare_similarity(query):
    vectorstore = init_db()
    # Fetch the five closest rules along with their relevance scores.
    result = vectorstore.similarity_search_with_relevance_scores(query, k=5)
    score_list = [item[-1] for item in result]
    try:
        return sum(score_list) / len(score_list)
    except ZeroDivisionError:
        # No rules were returned: treat the prompt as having zero similarity.
        return 0
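
For intuition on the threshold, suppose the five nearest rules return the relevance scores below (the numbers are made up):

# Made-up relevance scores for the k=5 nearest rules.
scores = [0.82, 0.78, 0.75, 0.71, 0.69]
average = sum(scores) / len(scores)  # 0.75
# 0.75 > 0.7, so /query would forward this prompt to the chatbot.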

Deploying on AWS EC2

Deployment is done using the screen utility, which keeps the server running after you disconnect from the instance:

  1. Create a session with a specific name:

screen -S name
  2. Navigate to the API folder:

cd path/to/api
  3. Start FastAPI on port 8000, listening on all interfaces:

uvicorn guardrail:app --host 0.0.0.0 --port 8000
  4. Detach from the current screen session:

Ctrl+a d

API Gateway and CORS Configuration

Add (or confirm) the CORS configuration for FastAPI so the frontend can call the API from a different origin:

from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

Connect this server to AWS API Gateway to handle authentication and usage management for each request.

Link to Backend Code
