blob: 2e7c1674fd0aae9c28506cbfb2ef76e3585c2584 [file] [log] [blame]
#' Retrieves text embeddings for character input from a vector from the GPT-3 API
#'
#' @description
#' `gpt3_embeddings()` extends the single embeddings function `gpt3_single_embedding()` to allow for the processing of a whole vector.
#' @details The returned data.table contains the column `id` which indicates the text id (or its generic alternative if not specified) and the columns `dim_1` ... `dim_{max}`, where `max` is the length of the text embeddings vector that the four different models return. For the default "Ada" model, these are 1024 dimensions (i.e., `dim_1`... `dim_1024`).
#'
#' The function supports the text similarity embeddings for the four GPT-3 models as specified in the parameter list. The main difference between the four models is the sophistication of the embedding representation as indicated by the vector embedding size.
#' - Ada (1024 dimensions)
#' - Babbage (2048 dimensions)
#' - Curie (4096 dimensions)
#' - Davinci (12288 dimensions)
#'
#' Note that the dimension size (= vector length), speed and [associated costs](https://openai.com/api/pricing/) differ considerably.
#'
#' These vectors can be used for downstream tasks such as (vector) similarity calculations.
#'
#' One request is made per element of `input_var` (via `gpt3_single_embedding()`), and a progress line is printed for each element. If `input_var` has length zero, no request is made and an empty data.table is returned.
#' @param input_var character vector that contains the texts for which you want to obtain text embeddings from the GPT-3 model
#' @param id_var (optional) character vector that contains the user-defined ids of the prompts. Must have the same length as `input_var`; an error is raised otherwise. If omitted, generic ids of the form `prompt_1`, `prompt_2`, ... are used. See details.
#' @param param_model a character vector that indicates the [similarity embedding model](https://beta.openai.com/docs/guides/embeddings/similarity-embeddings); one of "text-similarity-ada-001" (default), "text-similarity-curie-001", "text-similarity-babbage-001", "text-similarity-davinci-001"
#' @return A data.table with the embeddings as separate columns; one row represents one input text. See details.
#' @examples
#' # First authenticate with your API key via `gpt3_authenticate('pathtokey')`
#'
#' # Use example data:
#' ## The data below were generated with the `gpt3_single_request()` function as follows:
#' ##### DO NOT RUN #####
#' # travel_blog_data = gpt3_single_request(prompt_input = "Write a travel blog about a dog's journey through the UK:", temperature = 0.8, n = 10, max_tokens = 200)[[1]]
#' ##### END DO NOT RUN #####
#'
#' # You can load these data with:
#' data("travel_blog_data") # the dataset contains 10 completions for the above request
#'
#' ## Obtain text embeddings for the completion texts (requires API access):
#' \dontrun{
#' emb_travelblogs = gpt3_embeddings(input_var = travel_blog_data$gpt3)
#' dim(emb_travelblogs)
#' }
#' @export
gpt3_embeddings = function(input_var
                           , id_var
                           , param_model = 'text-similarity-ada-001'){

  data_length = length(input_var)

  if(missing(id_var)){
    # seq_len() instead of 1:n so that empty input does not produce c(1, 0)
    data_id = paste0('prompt_', seq_len(data_length))
  } else {
    # A mismatched id vector would otherwise silently yield NA ids or drop ids
    if(length(id_var) != data_length){
      stop('`id_var` must have the same length as `input_var`.', call. = FALSE)
    }
    data_id = id_var
  }

  # Preallocate one slot per input text instead of growing the list
  embedding_rows = vector('list', data_length)

  for(i in seq_len(data_length)){

    print(paste0('Embedding: ', i, '/', data_length))

    # One API request per input text; the result is presumably a numeric
    # embedding vector (defined in gpt3_single_embedding(), not shown here)
    row_outcome = gpt3_single_embedding(model = param_model
                                        , input = input_var[i])

    # Transpose so that the embedding becomes a single row with one column
    # per dimension, then label the columns dim_1 ... dim_k
    row_df = data.frame(t(row_outcome))
    names(row_df) = paste0('dim_', seq_along(row_outcome))
    row_df$id = data_id[i]

    embedding_rows[[i]] = row_df

  }

  # Stack the single-row data frames into one data.table
  output_data = data.table::rbindlist(embedding_rows)

  return(output_data)

}