system HF Staff committed on
Commit
930e07a
·
0 Parent(s):

initial commit

Browse files
Files changed (5) hide show
  1. .gitattributes +27 -0
  2. README.md +37 -0
  3. app.py +61 -0
  4. models.pt +3 -0
  5. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Nlp_course_hw_style_transfer
3
+ emoji: 👁
4
+ colorFrom: purple
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (only emoji characters are allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio` or `streamlit`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
st.set_page_config(page_title='Review improver', layout='centered')

st.markdown('Hello!')


# Streamlit re-executes this script on every widget interaction. Cache the
# loaded checkpoint (about 1.3 GB on disk) so it is deserialized once per
# process rather than on every rerun. Streamlit releases that predate
# ``st.cache_resource`` fall back to an uncached loader.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_models(path='models.pt'):
    """Load the two masked-LM models, the classifier and the tokenizer.

    Args:
        path: Checkpoint file holding a 4-tuple
            ``(mlm_positive, mlm_negative, classifier, tokenizer)``.

    Returns:
        The same 4-tuple, with the three models switched to eval mode.
    """
    # NOTE(security): the checkpoint stores whole pickled Python objects, so
    # loading it can execute arbitrary code. Only load trusted files.
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle full model/tokenizer objects.
        loaded = torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` argument.
        loaded = torch.load(path, map_location='cpu')

    mlm_positive, mlm_negative, classifier, loaded_tokenizer = loaded

    # Assumes the three models are torch.nn.Module instances; eval mode turns
    # off dropout so that inference is deterministic.
    for model in (mlm_positive, mlm_negative, classifier):
        model.eval()

    return mlm_positive, mlm_negative, classifier, loaded_tokenizer


# load the necessary models
bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer = _load_models()
13
+
14
+
15
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token rewrites that make ``sentence`` look more positive.

    Each token is scored by the ratio of the probability the positive
    masked-LM assigns to it over the probability the negative masked-LM
    assigns to it. The ``num_tokens`` positions with the smallest ratio are
    selected, and each is replaced in turn by the ``k_best`` tokens the
    positive model considers most likely at that position.

    Args:
        sentence: Text to rewrite.
        num_tokens: Maximum number of positions to try replacing.
        k_best: Number of replacement tokens to try per position.
        epsilon: Smoothing term added to both probabilities so that the
            ratio stays stable when they are close to zero.

    Returns:
        A list of decoded candidate sentences, each differing from the input
        in at most one token. If the input has no replaceable tokens (only
        special tokens), the list holds the unchanged ``sentence`` so that
        callers always have a candidate to score.
    """
    inputs = tokenizer(sentence, return_tensors='pt')
    tokens = inputs['input_ids'][0]
    positions = torch.arange(len(tokens))

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        vocab_logits_positive = bert_mlm_positive(**inputs)['logits'][0]
        vocab_logits_negative = bert_mlm_negative(**inputs)['logits'][0]

    vocab_probs_positive = F.softmax(vocab_logits_positive, dim=1)
    vocab_probs_negative = F.softmax(vocab_logits_negative, dim=1)

    probs_positive = vocab_probs_positive[positions, tokens]
    probs_negative = vocab_probs_negative[positions, tokens]
    ratio = (probs_positive + epsilon) / (probs_negative + epsilon)

    # Special tokens added by the tokenizer must not be replaced: overwriting
    # them with an ordinary token would insert a spurious word into the
    # decoded sentence.
    special_ids = torch.tensor(tokenizer.all_special_ids, dtype=tokens.dtype)
    is_special = (tokens.unsqueeze(1) == special_ids.unsqueeze(0)).any(dim=1)
    candidate_positions = positions[~is_special]
    if len(candidate_positions) == 0:
        return [sentence]

    order = torch.argsort(ratio[candidate_positions])[:num_tokens]
    smallest_ratio_ids = candidate_positions[order]

    # Clamp k so that k_best == 0 yields no candidates; a plain ``[-0:]``
    # slice would select the entire vocabulary.
    vocab_size = vocab_probs_positive.shape[1]
    k = max(0, min(k_best, vocab_size))

    replacements = []
    for idx in smallest_ratio_ids:
        ranked_tokens = torch.argsort(vocab_probs_positive[idx])
        new_tokens = ranked_tokens[vocab_size - k:]
        for token in new_tokens:
            cur_replacement = tokens.clone()
            cur_replacement[idx] = token
            replacements.append(
                tokenizer.decode(cur_replacement, skip_special_tokens=True)
            )

    return replacements
41
+
42
+
43
def modify_sentence(sentence, num_iters=3, num_tokens=3, k_best=5):
    """Iteratively rewrite ``sentence`` towards the classifier's class 1.

    On every iteration, candidate single-token rewrites are generated with
    ``get_replacements`` and the candidate with the highest classifier logit
    at index 1 becomes the new sentence.

    Args:
        sentence: Text to improve.
        num_iters: Number of rewrite rounds.
        num_tokens: Positions considered for replacement in each round.
        k_best: Replacement tokens tried per position in each round.

    Returns:
        The sentence after the final round, or the latest sentence if a round
        produces no candidates.
    """
    for _ in range(num_iters):
        replacements = get_replacements(sentence, num_tokens=num_tokens, k_best=k_best)
        if not replacements:
            # Nothing to score; the tokenizer would fail on an empty batch.
            break

        classifier_inputs = tokenizer(replacements, padding=True, return_tensors='pt')
        # Inference only: no gradients are needed to rank the candidates.
        with torch.no_grad():
            logits = bert_classifier(**classifier_inputs)['logits'][:, 1]

        best_idx = int(torch.argmax(logits).item())
        sentence = replacements[best_idx]

    return sentence
54
+
55
+
56
# Read a review from the user. Once a non-empty value is submitted, run the
# rewrite first and then show a caption followed by the rewritten text.
user_input = st.text_input('Enter your review here and we will try to improve it:')
if user_input:
    improved_review = modify_sentence(user_input)
    for output_line in ('Here is your improved review:', improved_review):
        st.markdown(output_line)
models.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b01d9dfab0f7d5c49a5d84976384e24d1108b4fc04f500b79ff5fbfec48faf82
3
+ size 1315214587
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers