#' Retrieves text embeddings for character input from a vector from the GPT-3 API
#'
#' @description
#' `gpt3_bunch_embedding()` extends the single embeddings function `gpt3_make_embedding()` to allow for the processing of a whole vector.
#' @details The returned data.table contains the column `id` which indicates the text id (or its generic alternative if not specified) and the columns `dim_1` ... `dim_{max}`, where `max` is the length of the text embeddings vector that the four different models return. For the default "Ada" model, these are 1024 dimensions (i.e., `dim_1`... `dim_1024`).
#'
#' The function supports the text similarity embeddings for the four GPT-3 models as specified in the parameter list. The main difference between the four models is the sophistication of the embedding representation as indicated by the vector embedding size.
#'   - Ada (1024 dimensions)
#'   - Babbage (2048 dimensions)
#'   - Curie (4096 dimensions)
#'   - Davinci (12288 dimensions)
#'
#' Note that the dimension size (= vector length), speed and [associated costs](https://openai.com/api/pricing/) differ considerably.
#'
#' These vectors can be used for downstream tasks such as (vector) similarity calculations.
#'
#' One request is made per element of `input_var` (via `gpt3_make_embedding()`), and a progress line is printed for each element. If `input_var` has length zero, no requests are made.
#' @param input_var character vector that contains the texts for which you want to obtain text embeddings from the GPT-3 model
#' @param id_var (optional) character vector that contains the user-defined ids of the prompts. Must have the same length as `input_var`. If omitted, generic ids of the form `prompt_1`, `prompt_2`, ... are used. See details.
#' @param param_model a character vector that indicates the [similarity embedding model](https://beta.openai.com/docs/guides/embeddings/similarity-embeddings); one of "text-similarity-ada-001" (default), "text-similarity-curie-001", "text-similarity-babbage-001", "text-similarity-davinci-001"
#' @return A data.table with the embeddings as separate columns; one row represents one input text. See details.
#' @examples
#' # First authenticate with your API key via `gpt3_authenticate('pathtokey')`
#'
#' # Use example data:
#' ## The data below were generated with the `gpt3_make_request()` function as follows:
#' ##### DO NOT RUN #####
#' # travel_blog_data = gpt3_make_request(prompt_input = "Write a travel blog about a dog's journey through the UK:", temperature = 0.8, n = 10, max_tokens = 200)[[1]]
#' ##### END DO NOT RUN #####
#'
#' # You can load these data with:
#' data("travel_blog_data") # the dataset contains 10 completions for the above request
#'
#' ## Obtain text embeddings for the completion texts:
#' emb_travelblogs = gpt3_bunch_embedding(input_var = travel_blog_data$gpt3)
#' dim(emb_travelblogs)
#' @export
gpt3_bunch_embedding = function(input_var
                                , id_var
                                , param_model = 'text-similarity-ada-001'){

  data_length = length(input_var)

  if(missing(id_var)){
    # seq_len() (not 1:n) so that empty input yields zero ids instead of c(1, 0)
    data_id = paste0('prompt_', seq_len(data_length))
  } else {
    # A mismatched id vector would silently produce NA or misaligned ids
    if(length(id_var) != data_length){
      stop('`id_var` must have the same length as `input_var`.')
    }
    data_id = id_var
  }

  # Preallocate one slot per input text; each slot receives a one-row data.frame
  embedding_rows = vector('list', data_length)

  # seq_len() guarantees the loop body is skipped entirely for empty input
  for(i in seq_len(data_length)){

    print(paste0('Embedding: ', i, '/', data_length))

    row_outcome = gpt3_make_embedding(model = param_model
                                      , input = input_var[i])

    # Transpose so the embedding values become the columns of a single row
    row_df = data.frame(t(row_outcome))
    names(row_df) = paste0('dim_', seq_along(row_outcome))
    row_df$id = data_id[i]

    embedding_rows[[i]] = row_df

  }

  # Stack the per-text rows into a single data.table
  output_data = rbindlist(embedding_rows)

  return(output_data)

}