blob: 44fd78fc6fa52b707720085cea774715ad1257be [file] [log] [blame]
# Sends one completion request per row of `data`: the text in column
# `prompt_var` is passed to gpt3.make_request(), and the first returned
# choice's text is collected into a new column named `completion_var_name`.
#
# Parameters:
#   data                - a data.table with one prompt per row
#   prompt_var          - character; name of the column holding the prompt text
#   completion_var_name - name of the output column (default 'gpt3_completion')
#   param_*             - forwarded unchanged to gpt3.make_request()
#                         (model, sampling and penalty settings; see the
#                         OpenAI completions API for their semantics)
#
# Returns: a data.table — a copy of `data` with one added completion column.
#          The input `data` is NOT modified.
gpt3.bunch_request = function(data
                              , prompt_var
                              , completion_var_name = 'gpt3_completion'
                              , param_model = 'text-davinci-002'
                              , param_suffix = NULL
                              , param_max_tokens = 256
                              , param_temperature = 0.9
                              , param_top_p = 1
                              , param_n = 1
                              , param_stream = F
                              , param_logprobs = NULL
                              , param_echo = F
                              , param_stop = NULL
                              , param_presence_penalty = 0
                              , param_frequency_penalty = 0
                              , param_best_of = 1
                              , param_logit_bias = NULL){

  # data.table assignment is by reference: the original `data_ = data` aliased
  # the caller's table, so `:=` below mutated it. Take an explicit copy.
  data_ = data.table::copy(data)

  data_length = data_[, .N]

  # Pre-create the output column; it is filled row by row in the loop.
  data_[, completion_name := '']

  # seq_len() is safe for a zero-row table (1:0 would iterate with i = 1, 0
  # and trigger spurious out-of-range requests).
  for(i in seq_len(data_length)){

    # Simple progress indicator for long batches.
    print(paste0('Request: ', i, '/', data_length))

    row_outcome = gpt3.make_request(prompt = as.character(unname(data_[i, ..prompt_var]))
                                    , model = param_model
                                    , output_type = 'detail'
                                    , suffix = param_suffix
                                    , max_tokens = param_max_tokens
                                    , temperature = param_temperature
                                    , top_p = param_top_p
                                    , n = param_n
                                    , stream = param_stream
                                    , logprobs = param_logprobs
                                    , echo = param_echo
                                    , stop = param_stop
                                    , presence_penalty = param_presence_penalty
                                    , frequency_penalty = param_frequency_penalty
                                    , best_of = param_best_of
                                    , logit_bias = param_logit_bias)

    # Keep only the first choice's text, even when param_n > 1.
    data_$completion_name[i] = row_outcome$choices[[1]]$text

  }

  # Rename the placeholder column by NAME, not by position: renaming the last
  # column breaks if the input already contained a 'completion_name' column
  # (then := overwrote it in place and the last column is something else).
  # setnames() also avoids the full-table copy that base `names<-` forces.
  data.table::setnames(data_, 'completion_name', completion_var_name)

  return(data_)
}