shivrajanand committed
Commit e8f4897 · verified · 1 Parent(s): 7d97c1d

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +2 -0
  2. LICENSE.md +201 -0
  3. README.md +52 -3
  4. data/ud_pos_ner_dp_dev_san +0 -0
  5. data/ud_pos_ner_dp_dev_san_POS +0 -0
  6. data/ud_pos_ner_dp_dev_san_case +0 -0
  7. data/ud_pos_ner_dp_test_san +0 -0
  8. data/ud_pos_ner_dp_test_san_POS +0 -0
  9. data/ud_pos_ner_dp_test_san_case +0 -0
  10. data/ud_pos_ner_dp_train_san +0 -0
  11. data/ud_pos_ner_dp_train_san_POS +0 -0
  12. data/ud_pos_ner_dp_train_san_case +0 -0
  13. examples/GraphParser.py +703 -0
  14. examples/GraphParser_MRL.py +603 -0
  15. examples/SequenceTagger.py +597 -0
  16. examples/eval/conll03eval.v2 +336 -0
  17. examples/eval/conll06eval.pl +1826 -0
  18. examples/test_original_dcst.sh +110 -0
  19. run_san_LCM.sh +73 -0
  20. utils/__init__.py +7 -0
  21. utils/__pycache__/__init__.cpython-37.pyc +0 -0
  22. utils/__pycache__/load_word_embeddings.cpython-37.pyc +0 -0
  23. utils/io_/__init__.py +5 -0
  24. utils/io_/__pycache__/__init__.cpython-37.pyc +0 -0
  25. utils/io_/__pycache__/alphabet.cpython-37.pyc +0 -0
  26. utils/io_/__pycache__/instance.cpython-37.pyc +0 -0
  27. utils/io_/__pycache__/logger.cpython-37.pyc +0 -0
  28. utils/io_/__pycache__/prepare_data.cpython-37.pyc +0 -0
  29. utils/io_/__pycache__/reader.cpython-37.pyc +0 -0
  30. utils/io_/__pycache__/rearrange_splits.cpython-37.pyc +0 -0
  31. utils/io_/__pycache__/seeds.cpython-37.pyc +0 -0
  32. utils/io_/__pycache__/write_extra_labels.cpython-37.pyc +0 -0
  33. utils/io_/__pycache__/writer.cpython-37.pyc +0 -0
  34. utils/io_/alphabet.py +147 -0
  35. utils/io_/coarse_to_ma_dict.json +1 -0
  36. utils/io_/convert_ud_to_onto_format.py +74 -0
  37. utils/io_/instance.py +19 -0
  38. utils/io_/logger.py +15 -0
  39. utils/io_/prepare_data.py +397 -0
  40. utils/io_/reader.py +93 -0
  41. utils/io_/rearrange_splits.py +68 -0
  42. utils/io_/remove_xx.py +60 -0
  43. utils/io_/seeds.py +12 -0
  44. utils/io_/write_extra_labels.py +1592 -0
  45. utils/io_/writer.py +46 -0
  46. utils/load_word_embeddings.py +123 -0
  47. utils/models/__init__.py +3 -0
  48. utils/models/__pycache__/__init__.cpython-37.pyc +0 -0
  49. utils/models/__pycache__/parsing.cpython-37.pyc +0 -0
  50. utils/models/__pycache__/parsing_gating.cpython-37.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ saved_models
+ multilingual_word_embeddings
LICENSE.md ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021 Jivnesh Sandhan
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,52 @@
- ---
- license: apache-2.0
- ---
+ Official code for the paper ["A Little Pretraining Goes a Long Way: A Case Study on Dependency Parsing Task for Low-resource Morphologically Rich Languages"](https://arxiv.org/abs/2102.06551).
+ If you use this code, please cite our paper.
+
+ ## Requirements
+
+ * Python 3.7
+ * PyTorch 1.1.0
+ * CUDA 9.0
+ * Gensim 3.8.1
+
+ We assume that you have installed conda beforehand.
+
+ ```
+ conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
+ pip install gensim==3.8.1
+ ```
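+
+ As a quick, optional sanity check, the sketch below (using only standard `torch`/`gensim` attributes, nothing specific to this repository) confirms that the installed versions match the requirements above:
+
+ ```python
+ import torch
+ import gensim
+
+ # Versions should match the requirements listed above.
+ print("torch:", torch.__version__)            # expected 1.1.0
+ print("CUDA available:", torch.cuda.is_available())
+ print("gensim:", gensim.__version__)          # expected 3.8.1
+ ```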
+ ## Data
+ * Pretrained FastText embeddings for Sanskrit can be obtained from [here](https://drive.google.com/drive/folders/1JJMBjUZdqUY7WLYefBbA2zKaMHH3Mm18?usp=sharing). Make sure the `.vec` file is placed in the appropriate location (a quick loading check is sketched below).
+ * For multilingual experiments, we use [UD treebanks](https://universaldependencies.org/) and [pretrained FastText embeddings](https://fasttext.cc/docs/en/crawl-vectors.html).
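+
+ The training scripts read embeddings through `utils/load_word_embeddings.py`; the sketch below is only a quick way to verify that a downloaded `.vec` file loads cleanly with Gensim 3.8.1. The path and filename are placeholders (the `multilingual_word_embeddings/` directory listed in `.gitignore` is presumably where such files go):
+
+ ```python
+ from gensim.models import KeyedVectors
+
+ # Hypothetical path/filename -- adjust to wherever you placed the downloaded .vec file.
+ vec_path = "multilingual_word_embeddings/sanskrit_fasttext.vec"
+
+ # FastText .vec files use the plain word2vec text format.
+ vectors = KeyedVectors.load_word2vec_format(vec_path, binary=False)
+ print(len(vectors.vocab), "words, dimension", vectors.vector_size)
+ ```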
+
+ ## How to train the model
+ To run the complete model pipeline, i.e. (1) pretraining followed by (2) integration, simply run the bash script `run_san_LCM.sh`:
+
+ ```bash
+ bash run_san_LCM.sh
+ ```
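+
+ The two stages are connected through saved checkpoints: the pretraining stage (presumably driven by `examples/SequenceTagger.py` inside `run_san_LCM.sh`) writes tagger models, and the integration stage (`examples/GraphParser.py`) loads them via `--load_sequence_taggers_paths`, copying each tagger's BiLSTM encoder into an extra gated encoder of the parser. The sketch below condenses that loading step from `build_model_and_optimizer` in `examples/GraphParser.py`; the checkpoint paths are placeholders:
+
+ ```python
+ import torch
+
+ # Hypothetical checkpoint paths produced by the pretraining stage.
+ tagger_paths = ["saved_models/tagger_case.pt", "saved_models/tagger_POS.pt"]
+
+ pretrained = {}
+ for idx, ckpt_path in enumerate(tagger_paths):
+     ckpt = torch.load(ckpt_path, map_location="cpu")
+     # Keys under 'rnn_encoder.*' in each tagger checkpoint are remapped to
+     # 'extra_rnn_encoders.<idx>.*' so the gated parser can combine the encoders.
+     for key, value in ckpt["model_state_dict"].items():
+         if "rnn_encoder." in key:
+             new_key = "extra_rnn_encoders." + str(idx) + "." + key.replace("rnn_encoder.", "")
+             pretrained[new_key] = value
+
+ # GraphParser.py then merges these tensors into the parser's own state_dict:
+ #     model_dict.update(pretrained); model.load_state_dict(model_dict)
+ ```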
+
+ ## Citation
+
+ If you use our tool, we would appreciate it if you cite the following paper:
+
+ ```
+ @inproceedings{sandhan-etal-2021-little,
+ title = "A Little Pretraining Goes a Long Way: A Case Study on Dependency Parsing Task for Low-resource Morphologically Rich Languages",
+ author = "Sandhan, Jivnesh and Krishna, Amrith and Gupta, Ashim and Behera, Laxmidhar and Goyal, Pawan",
+ booktitle = "Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Student Research Workshop",
+ month = apr,
+ year = "2021",
+ address = "Online",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2021.eacl-srw.16",
+ doi = "10.18653/v1/2021.eacl-srw.16",
+ pages = "111--120",
+ abstract = "Neural dependency parsing has achieved remarkable performance for many domains and languages. The bottleneck of massive labelled data limits the effectiveness of these approaches for low resource languages. In this work, we focus on dependency parsing for morphological rich languages (MRLs) in a low-resource setting. Although morphological information is essential for the dependency parsing task, the morphological disambiguation and lack of powerful analyzers pose challenges to get this information for MRLs. To address these challenges, we propose simple auxiliary tasks for pretraining. We perform experiments on 10 MRLs in low-resource settings to measure the efficacy of our proposed pretraining method and observe an average absolute gain of 2 points (UAS) and 3.6 points (LAS).",
+ }
+ ```
+
+ ## Acknowledgements
+ Much of the base code is adapted from the ["DCST implementation"](https://github.com/rotmanguy/DCST).
data/ud_pos_ner_dp_dev_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_dev_san_POS ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_dev_san_case ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_test_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_test_san_POS ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_test_san_case ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_train_san ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_train_san_POS ADDED
The diff for this file is too large to render. See raw diff
 
data/ud_pos_ner_dp_train_san_case ADDED
The diff for this file is too large to render. See raw diff
 
examples/GraphParser.py ADDED
@@ -0,0 +1,703 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ from os import path, makedirs
4
+
5
+ sys.path.append(".")
6
+ sys.path.append("..")
7
+
8
+ import argparse
9
+ from copy import deepcopy
10
+ import json
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from collections import namedtuple
15
+ from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
16
+ from utils.models.parsing_gating import BiAffine_Parser_Gated
17
+ from utils import load_word_embeddings
18
+ from utils.tasks import parse
19
+ import time
20
+ from torch.nn.utils import clip_grad_norm_
21
+ from torch.optim import Adam, SGD
22
+ import uuid
23
+ import pdb
24
+ uid = uuid.uuid4().hex[:6]
25
+
26
+ logger = get_logger('GraphParser')
27
+
28
+ def read_arguments():
29
+ args_ = argparse.ArgumentParser(description='Solving GraphParser')
30
+ args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
31
+ args_.add_argument('--domain', help='domain/language', required=True)
32
+ args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
33
+ required=True)
34
+ args_.add_argument('--gating',action='store_true', help='use gated mechanism')
35
+ args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
36
+ args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
37
+ args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
38
+ args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
39
+ args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
40
+ args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
41
+ args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
42
+ args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
43
+ args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
44
+ args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
45
+ args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
46
+ args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
47
+ args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
48
+ args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
49
+ args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
50
+ args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
51
+ args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
52
+ args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
53
+ args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
54
+ args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
55
+ args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
56
+ args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
57
+ args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
58
+ args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
59
+ args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
60
+ args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
61
+ args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
62
+ args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
63
+ args_.add_argument('--unk_replace', type=float, default=0.,
64
+ help='The rate to replace a singleton word with UNK')
65
+ args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
66
+ args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
67
+ help='Embedding for words')
68
+ args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
69
+ args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embeddings (disable fine-tuning).')
70
+ args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
71
+ args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
72
+ required=True)
73
+ args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
74
+ required=True)
75
+ args_.add_argument('--char_path', help='path for character embedding dict')
76
+ args_.add_argument('--pos_path', help='path for pos embedding dict')
77
+ args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
78
+ args_.add_argument('--model_path', help='path for saving model file.', required=True)
79
+ args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
80
+ args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
81
+ args_.add_argument('--strict', action='store_true', help='if True, the loaded model state should contain '
82
+ 'exactly the same keys as the current model')
83
+ args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
84
+ args_.add_argument('--eval_with_CI', action='store_true', help='evaluating model in constrained inference mode')
85
+ args_.add_argument('--LCM_Path_flag', action='store_true', help='for constrained inference with LCM, flag is used to change path')
86
+ args = args_.parse_args()
87
+ args_dict = {}
88
+ args_dict['dataset'] = args.dataset
89
+ args_dict['domain'] = args.domain
90
+ args_dict['rnn_mode'] = args.rnn_mode
91
+ args_dict['gating'] = args.gating
92
+ args_dict['num_gates'] = args.num_gates
93
+ args_dict['arc_decode'] = args.arc_decode
94
+ # args_dict['splits'] = ['train', 'dev', 'test']
95
+ args_dict['splits'] = ['train', 'dev', 'test']
96
+ args_dict['model_path'] = args.model_path
97
+ if not path.exists(args_dict['model_path']):
98
+ makedirs(args_dict['model_path'])
99
+ args_dict['data_paths'] = {}
100
+ if args_dict['dataset'] == 'ontonotes':
101
+ data_path = 'data/onto_pos_ner_dp'
102
+ else:
103
+ data_path = 'data/ud_pos_ner_dp'
104
+ for split in args_dict['splits']:
105
+ args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
106
+ ###################################
107
+
108
+ ###################################
109
+ args_dict['alphabet_data_paths'] = {}
110
+ for split in args_dict['splits']:
111
+ if args_dict['dataset'] == 'ontonotes':
112
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
113
+ else:
114
+ if '_' in args_dict['domain']:
115
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
116
+ else:
117
+ args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
118
+ args_dict['model_name'] = 'domain_' + args_dict['domain']
119
+ args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
120
+ args_dict['load_path'] = args.load_path
121
+ args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
122
+ if args_dict['load_sequence_taggers_paths'] is not None:
123
+ args_dict['gating'] = True
124
+ args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
125
+ else:
126
+ if not args_dict['gating']:
127
+ args_dict['num_gates'] = 0
128
+ args_dict['strict'] = args.strict
129
+ args_dict['num_epochs'] = args.num_epochs
130
+ args_dict['batch_size'] = args.batch_size
131
+ args_dict['hidden_size'] = args.hidden_size
132
+ args_dict['arc_space'] = args.arc_space
133
+ args_dict['arc_tag_space'] = args.arc_tag_space
134
+ args_dict['num_layers'] = args.num_layers
135
+ args_dict['num_filters'] = args.num_filters
136
+ args_dict['kernel_size'] = args.kernel_size
137
+ args_dict['learning_rate'] = args.learning_rate
138
+ args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
139
+ args_dict['opt'] = args.opt
140
+ args_dict['momentum'] = args.momentum
141
+ args_dict['betas'] = tuple(args.betas)
142
+ args_dict['epsilon'] = args.epsilon
143
+ args_dict['decay_rate'] = args.decay_rate
144
+ args_dict['clip'] = args.clip
145
+ args_dict['gamma'] = args.gamma
146
+ args_dict['schedule'] = args.schedule
147
+ args_dict['p_rnn'] = tuple(args.p_rnn)
148
+ args_dict['p_in'] = args.p_in
149
+ args_dict['p_out'] = args.p_out
150
+ args_dict['unk_replace'] = args.unk_replace
151
+ args_dict['set_num_training_samples'] = args.set_num_training_samples
152
+ args_dict['punct_set'] = None
153
+ if args.punct_set is not None:
154
+ args_dict['punct_set'] = set(args.punct_set)
155
+ logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
156
+ args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
157
+ args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
158
+ args_dict['word_embedding'] = args.word_embedding
159
+ args_dict['word_path'] = args.word_path
160
+ args_dict['use_char'] = args.use_char
161
+ args_dict['char_embedding'] = args.char_embedding
162
+ args_dict['char_path'] = args.char_path
163
+ args_dict['pos_embedding'] = args.pos_embedding
164
+ args_dict['pos_path'] = args.pos_path
165
+ args_dict['use_pos'] = args.use_pos
166
+ args_dict['pos_dim'] = args.pos_dim
167
+ args_dict['word_dict'] = None
168
+ args_dict['word_dim'] = args.word_dim
169
+ if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
170
+ args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
171
+ args_dict['word_path'])
172
+ args_dict['char_dict'] = None
173
+ args_dict['char_dim'] = args.char_dim
174
+ if args_dict['char_embedding'] != 'random':
175
+ args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
176
+ args_dict['char_path'])
177
+ args_dict['pos_dict'] = None
178
+ if args_dict['pos_embedding'] != 'random':
179
+ args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
180
+ args_dict['pos_path'])
181
+ args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
182
+ args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
183
+ args_dict['eval_mode'] = args.eval_mode
184
+ args_dict['eval_with_CI'] = args.eval_with_CI
185
+ args_dict['LCM_Path_flag'] = args.LCM_Path_flag
186
+ args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
187
+ args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
188
+ args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
189
+ args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
190
+ logger.info("Saving arguments to file")
191
+ save_args(args, args_dict['full_model_name'])
192
+ logger.info("Creating Alphabets")
193
+ alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
194
+ args_dict = {**args_dict, **alphabet_dict}
195
+ ARGS = namedtuple('ARGS', args_dict.keys())
196
+ my_args = ARGS(**args_dict)
197
+ return my_args
198
+
199
+
200
+ def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
201
+ train_paths = alphabet_data_paths['train']
202
+ extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
203
+ alphabet_dict = {}
204
+ alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
205
+ train_paths,
206
+ extra_paths=extra_paths,
207
+ max_vocabulary_size=100000,
208
+ embedd_dict=word_dict)
209
+ for k, v in alphabet_dict['alphabets'].items():
210
+ num_key = 'num_' + k.split('_')[0]
211
+ alphabet_dict[num_key] = v.size()
212
+ logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
213
+ return alphabet_dict
214
+
215
+ def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
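+ # Build an embedding matrix aligned with the alphabet's indices: tokens found in the
+ # pretrained dictionary (matched exactly, then lower-cased) keep their vectors, while
+ # UNK and out-of-vocabulary tokens are drawn uniformly from [-sqrt(3/dim), sqrt(3/dim)].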
216
+ if tokens_dict is None:
217
+ return None
218
+ scale = np.sqrt(3.0 / dim)
219
+ table = np.empty([alphabet.size(), dim], dtype=np.float32)
220
+ table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
221
+ oov_tokens = 0
222
+ for token, index in alphabet.items():
223
+ if token in tokens_dict:
224
+ embedding = tokens_dict[token]
225
+ elif token.lower() in tokens_dict:
226
+ embedding = tokens_dict[token.lower()]
227
+ else:
228
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
229
+ oov_tokens += 1
230
+ table[index, :] = embedding
231
+ print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
232
+ table = torch.from_numpy(table)
233
+ return table
234
+
235
+ def save_args(args, full_model_name):
236
+ arg_path = full_model_name + '.arg.json'
237
+ argparse_dict = vars(args)
238
+ with open(arg_path, 'w') as f:
239
+ json.dump(argparse_dict, f)
240
+
241
+ def generate_optimizer(args, lr, params):
242
+ params = filter(lambda param: param.requires_grad, params)
243
+ if args.opt == 'adam':
244
+ return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
245
+ elif args.opt == 'sgd':
246
+ return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
247
+ else:
248
+ raise ValueError('Unknown optimization algorithm: %s' % args.opt)
249
+
250
+
251
+ def save_checkpoint(args, model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
252
+ path_name = full_model_name + '.pt'
253
+ print('Saving model to: %s' % path_name)
254
+ state = {'model_state_dict': model.state_dict(),
255
+ 'optimizer_state_dict': optimizer.state_dict(),
256
+ 'opt': opt,
257
+ 'dev_eval_dict': dev_eval_dict,
258
+ 'test_eval_dict': test_eval_dict}
259
+ torch.save(state, path_name)
260
+
261
+ def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
262
+ print('Loading saved model from: %s' % load_path)
263
+ checkpoint = torch.load(load_path, map_location=args.device)
264
+ if checkpoint['opt'] != args.opt:
265
+ raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
266
+ model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
267
+
268
+
269
+ if strict:
270
+ generate_optimizer(args, args.learning_rate, model.parameters())
271
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
272
+ for state in optimizer.state.values():
273
+ for k, v in state.items():
274
+ if isinstance(v, torch.Tensor):
275
+ state[k] = v.to(args.device)
276
+ dev_eval_dict = checkpoint['dev_eval_dict']
277
+ test_eval_dict = checkpoint['test_eval_dict']
278
+ start_epoch = dev_eval_dict['in_domain']['epoch']
279
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
280
+
281
+
282
+ def build_model_and_optimizer(args):
283
+ word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
284
+ char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
285
+ pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
286
+ model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
287
+ args.use_pos, args.use_char, args.pos_dim, args.num_pos,
288
+ args.num_filters, args.kernel_size, args.rnn_mode,
289
+ args.hidden_size, args.num_layers, args.num_arc,
290
+ args.arc_space, args.arc_tag_space, args.num_gates,
291
+ embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
292
+ p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
293
+ biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
294
+ print(model)
295
+ optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
296
+ start_epoch = 0
297
+ dev_eval_dict = {'in_domain': initialize_eval_dict()}
298
+ test_eval_dict = {'in_domain': initialize_eval_dict()}
299
+ if args.load_path:
300
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
301
+ load_checkpoint(args, model, optimizer,
302
+ dev_eval_dict, test_eval_dict,
303
+ start_epoch, args.load_path, strict=args.strict)
304
+ if args.load_sequence_taggers_paths:
305
+ pretrained_dict = {}
306
+ model_dict = model.state_dict()
307
+ for idx, path in enumerate(args.load_sequence_taggers_paths):
308
+ print('Loading saved sequence_tagger from: %s' % path)
309
+ checkpoint = torch.load(path, map_location=args.device)
310
+ for k, v in checkpoint['model_state_dict'].items():
311
+ if 'rnn_encoder.' in k:
312
+ pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
313
+ model_dict.update(pretrained_dict)
314
+ model.load_state_dict(model_dict)
315
+ if args.freeze_sequence_taggers:
316
+ print('Freezing Classifiers')
317
+ for name, parameter in model.named_parameters():
318
+ if 'extra_rnn_encoders' in name:
319
+ parameter.requires_grad = False
320
+ if args.freeze_word_embeddings:
321
+ model.rnn_encoder.word_embedd.weight.requires_grad = False
322
+ # model.rnn_encoder.char_embedd.weight.requires_grad = False
323
+ # model.rnn_encoder.pos_embedd.weight.requires_grad = False
324
+ device = args.device
325
+ model.to(device)
326
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
327
+
328
+
329
+ def initialize_eval_dict():
330
+ eval_dict = {}
331
+ eval_dict['dp_uas'] = 0.0
332
+ eval_dict['dp_las'] = 0.0
333
+ eval_dict['epoch'] = 0
334
+ eval_dict['dp_ucorrect'] = 0.0
335
+ eval_dict['dp_lcorrect'] = 0.0
336
+ eval_dict['dp_total'] = 0.0
337
+ eval_dict['dp_ucomplete_match'] = 0.0
338
+ eval_dict['dp_lcomplete_match'] = 0.0
339
+ eval_dict['dp_ucorrect_nopunc'] = 0.0
340
+ eval_dict['dp_lcorrect_nopunc'] = 0.0
341
+ eval_dict['dp_total_nopunc'] = 0.0
342
+ eval_dict['dp_ucomplete_match_nopunc'] = 0.0
343
+ eval_dict['dp_lcomplete_match_nopunc'] = 0.0
344
+ eval_dict['dp_root_correct'] = 0.0
345
+ eval_dict['dp_total_root'] = 0.0
346
+ eval_dict['dp_total_inst'] = 0.0
347
+ eval_dict['dp_total'] = 0.0
348
+ eval_dict['dp_total_inst'] = 0.0
349
+ eval_dict['dp_total_nopunc'] = 0.0
350
+ eval_dict['dp_total_root'] = 0.0
351
+ return eval_dict
352
+
353
+ def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
354
+ best_model, best_optimizer, patient):
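+ # Model selection: keep the checkpoint with the most labelled-correct (LAS) arcs on the
+ # dev set (punctuation excluded), breaking ties by unlabelled-correct (UAS) arcs; the test
+ # split is only re-evaluated when dev improves, and `patient` counts epochs without gain.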
355
+ # In-domain evaluation
356
+ curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
357
+ is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
358
+ (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
359
+ dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
360
+
361
+ if is_best_in_domain:
362
+ for key, value in curr_dev_eval_dict.items():
363
+ dev_eval_dict['in_domain'][key] = value
364
+ curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
365
+ for key, value in curr_test_eval_dict.items():
366
+ test_eval_dict['in_domain'][key] = value
367
+ best_model = deepcopy(model)
368
+ best_optimizer = deepcopy(optimizer)
369
+ patient = 0
370
+ else:
371
+ patient += 1
372
+ if epoch == args.num_epochs:
373
+ # save in-domain checkpoint
374
+ if args.set_num_training_samples is not None:
375
+ splits_to_write = datasets.keys()
376
+ else:
377
+ splits_to_write = ['dev', 'test']
378
+ for split in splits_to_write:
379
+ if split == 'dev':
380
+ eval_dict = dev_eval_dict['in_domain']
381
+ elif split == 'test':
382
+ eval_dict = test_eval_dict['in_domain']
383
+ else:
384
+ eval_dict = None
385
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
386
+ print("Saving best model")
387
+ save_checkpoint(args, best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
388
+
389
+ print('\n')
390
+ return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
391
+
392
+
393
+ def evaluation(args, data, split, model, domain, epoch, str_res='results'):
394
+ # evaluate performance on data
395
+ model.eval()
396
+
397
+ eval_dict = initialize_eval_dict()
398
+ eval_dict['epoch'] = epoch
399
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
400
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
401
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
402
+ heads_pred, arc_tags_pred, _ = model.decode(args.model_path, word, pos,ner, out_arc, out_arc_tag, mask=masks, length=lengths,
403
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
404
+ lengths = lengths.cpu().numpy()
405
+ word = word.data.cpu().numpy()
406
+ pos = pos.data.cpu().numpy()
407
+ ner = ner.data.cpu().numpy()
408
+ heads = heads.data.cpu().numpy()
409
+ arc_tags = arc_tags.data.cpu().numpy()
410
+ heads_pred = heads_pred.data.cpu().numpy()
411
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
412
+ stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
413
+ arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
414
+ lengths, punct_set=args.punct_set, symbolic_root=True)
415
+ ucorr, lcorr, total, ucm, lcm = stats
416
+ ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
417
+ corr_root, total_root = stats_root
418
+ eval_dict['dp_ucorrect'] += ucorr
419
+ eval_dict['dp_lcorrect'] += lcorr
420
+ eval_dict['dp_total'] += total
421
+ eval_dict['dp_ucomplete_match'] += ucm
422
+ eval_dict['dp_lcomplete_match'] += lcm
423
+ eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
424
+ eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
425
+ eval_dict['dp_total_nopunc'] += total_nopunc
426
+ eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
427
+ eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
428
+ eval_dict['dp_root_correct'] += corr_root
429
+ eval_dict['dp_total_root'] += total_root
430
+ eval_dict['dp_total_inst'] += num_inst
431
+
432
+ eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
433
+ eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
434
+ print_results(eval_dict, split, domain, str_res)
435
+ return eval_dict
436
+
437
+ def constrained_evaluation(args, data, split, model, domain, epoch, str_res='results'):
438
+ # evaluate performance on data
439
+ model.eval()
440
+
441
+ eval_dict = initialize_eval_dict()
442
+ eval_dict['epoch'] = epoch
443
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
444
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
445
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
446
+ heads_pred, arc_tags_pred, _ = model.constrained_decode(args, word, pos,ner, out_arc, out_arc_tag, mask=masks, length=lengths,
447
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
448
+ lengths = lengths.cpu().numpy()
449
+ word = word.data.cpu().numpy()
450
+ pos = pos.data.cpu().numpy()
451
+ ner = ner.data.cpu().numpy()
452
+ heads = heads.data.cpu().numpy()
453
+ arc_tags = arc_tags.data.cpu().numpy()
454
+ heads_pred = heads_pred.data.cpu().numpy()
455
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
456
+ stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
457
+ arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
458
+ lengths, punct_set=args.punct_set, symbolic_root=True)
459
+ ucorr, lcorr, total, ucm, lcm = stats
460
+ ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
461
+ corr_root, total_root = stats_root
462
+ eval_dict['dp_ucorrect'] += ucorr
463
+ eval_dict['dp_lcorrect'] += lcorr
464
+ eval_dict['dp_total'] += total
465
+ eval_dict['dp_ucomplete_match'] += ucm
466
+ eval_dict['dp_lcomplete_match'] += lcm
467
+ eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
468
+ eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
469
+ eval_dict['dp_total_nopunc'] += total_nopunc
470
+ eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
471
+ eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
472
+ eval_dict['dp_root_correct'] += corr_root
473
+ eval_dict['dp_total_root'] += total_root
474
+ eval_dict['dp_total_inst'] += num_inst
475
+
476
+ eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
477
+ eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
478
+ print_results(eval_dict, split, domain, str_res)
479
+ return eval_dict
480
+
481
+ def print_results(eval_dict, split, domain, str_res='results'):
482
+ print('----------------------------------------------------------------------------------------------------------------------------')
483
+ print('Testing model on domain %s' % domain)
484
+ print('--------------- Dependency Parsing - %s ---------------' % split)
485
+ print(
486
+ str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
487
+ eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
488
+ eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
489
+ eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
490
+ eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
491
+ eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
492
+ eval_dict['epoch']))
493
+ print(
494
+ str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
495
+ eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
496
+ eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
497
+ eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
498
+ eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
499
+ eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
500
+ eval_dict['epoch']))
501
+ print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
502
+ eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
503
+ eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
504
+ print('\n')
505
+ def constrained_write_results(args, data, data_domain, split, model, model_domain, eval_dict):
506
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
507
+ res_filename = str_file + '_res.txt'
508
+ pred_filename = str_file + '_pred.txt'
509
+ gold_filename = str_file + '_gold.txt'
510
+ if eval_dict is not None:
511
+ # save results dictionary into a file
512
+ with open(res_filename, 'w') as f:
513
+ json.dump(eval_dict, f)
514
+
515
+ # save predictions and gold labels into files
516
+ pred_writer = Writer(args.alphabets)
517
+ gold_writer = Writer(args.alphabets)
518
+ pred_writer.start(pred_filename)
519
+ gold_writer.start(gold_filename)
520
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
521
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
522
+ # pdb.set_trace()
523
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
524
+ heads_pred, arc_tags_pred, _ = model.constrained_decode(args,word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
525
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
526
+ lengths = lengths.cpu().numpy()
527
+ word = word.data.cpu().numpy()
528
+ pos = pos.data.cpu().numpy()
529
+ ner = ner.data.cpu().numpy()
530
+ heads = heads.data.cpu().numpy()
531
+ arc_tags = arc_tags.data.cpu().numpy()
532
+ heads_pred = heads_pred.data.cpu().numpy()
533
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
534
+ # print('words',word)
535
+ # print('Pos',pos)
536
+
537
+ # print('heads_pred',heads_pred)
538
+ # print('arc_tags_pred',arc_tags_pred)
539
+ # pdb.set_trace()
540
+ # writing predictions
541
+ pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
542
+ # writing gold labels
543
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
544
+
545
+ pred_writer.close()
546
+ gold_writer.close()
547
+ def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
548
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
549
+ res_filename = str_file + '_res.txt'
550
+ pred_filename = str_file + '_pred.txt'
551
+ gold_filename = str_file + '_gold.txt'
552
+ if eval_dict is not None:
553
+ # save results dictionary into a file
554
+ with open(res_filename, 'w') as f:
555
+ json.dump(eval_dict, f)
556
+
557
+ # save predictions and gold labels into files
558
+ pred_writer = Writer(args.alphabets)
559
+ gold_writer = Writer(args.alphabets)
560
+ pred_writer.start(pred_filename)
561
+ gold_writer.start(gold_filename)
562
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
563
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
564
+ # pdb.set_trace()
565
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
566
+ heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
567
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
568
+ lengths = lengths.cpu().numpy()
569
+ word = word.data.cpu().numpy()
570
+ pos = pos.data.cpu().numpy()
571
+ ner = ner.data.cpu().numpy()
572
+ heads = heads.data.cpu().numpy()
573
+ arc_tags = arc_tags.data.cpu().numpy()
574
+ heads_pred = heads_pred.data.cpu().numpy()
575
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
576
+ # print('words',word)
577
+ # print('Pos',pos)
578
+
579
+ # print('heads_pred',heads_pred)
580
+ # print('arc_tags_pred',arc_tags_pred)
581
+ # pdb.set_trace()
582
+ # writing predictions
583
+ pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
584
+ # writing gold labels
585
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
586
+
587
+ pred_writer.close()
588
+ gold_writer.close()
589
+
590
+ def main():
591
+ logger.info("Reading and creating arguments")
592
+ args = read_arguments()
593
+ logger.info("Reading Data")
594
+ datasets = {}
595
+ for split in args.splits:
596
+ print("Splits are:",split)
597
+ dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
598
+ symbolic_root=True)
599
+ datasets[split] = dataset
600
+ if args.set_num_training_samples is not None:
601
+ print('Setting train and dev to %d samples' % args.set_num_training_samples)
602
+ datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
603
+ logger.info("Creating Networks")
604
+ num_data = sum(datasets['train'][1])
605
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
606
+ best_model = deepcopy(model)
607
+ best_optimizer = deepcopy(optimizer)
608
+
609
+ logger.info('Training INFO of in domain %s' % args.domain)
610
+ logger.info('Training on Dependency Parsing')
611
+ logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
612
+ logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
613
+ logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
614
+ logger.info("num_epochs: %d" % (args.num_epochs))
615
+ print('\n')
616
+
617
+ if not args.eval_mode:
618
+ logger.info("Training")
619
+ num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
620
+ lr = args.learning_rate
621
+ patient = 0
622
+ decay = 0
623
+ for epoch in range(start_epoch + 1, args.num_epochs + 1):
624
+ print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
625
+ epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
626
+ model.train()
627
+ total_loss = 0.0
628
+ total_arc_loss = 0.0
629
+ total_arc_tag_loss = 0.0
630
+ total_train_inst = 0.0
631
+
632
+ train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
633
+ datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
634
+ start_time = time.time()
635
+ batch_num = 0
636
+ for batch_num, batch in enumerate(train_iter):
637
+ batch_num = batch_num + 1
638
+ optimizer.zero_grad()
639
+ # compute loss of main task
640
+ word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
641
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
642
+ loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
643
+ loss = loss_arc + loss_arc_tag
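+ # The training objective is the sum of the arc (head-selection) loss and the arc-tag
+ # (dependency-label) loss; below, losses are accumulated weighted by the number of real
+ # tokens (mask sum minus one symbolic ROOT per sentence) for logging.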
644
+ # pdb.set_trace()
645
+
646
+ # update losses
647
+ num_insts = masks.data.sum() - word.size(0)
648
+ total_arc_loss += loss_arc.item() * num_insts
649
+ total_arc_tag_loss += loss_arc_tag.item() * num_insts
650
+ total_loss += loss.item() * num_insts
651
+ total_train_inst += num_insts
652
+ # optimize parameters
653
+ loss.backward()
654
+ clip_grad_norm_(model.parameters(), args.clip)
655
+ optimizer.step()
656
+
657
+ time_ave = (time.time() - start_time) / batch_num
658
+ time_left = (num_batches - batch_num) * time_ave
659
+
660
+ # update log
661
+ if batch_num % 50 == 0:
662
+ log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
663
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
664
+ total_arc_tag_loss / total_train_inst, time_left)
665
+ sys.stdout.write(log_info)
666
+ sys.stdout.write('\n')
667
+ sys.stdout.flush()
668
+ print('\n')
669
+ print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
670
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
671
+ total_arc_tag_loss / total_train_inst, time.time() - start_time))
672
+
673
+ dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
674
+ if patient >= args.schedule:
675
+ lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
676
+ optimizer = generate_optimizer(args, lr, model.parameters())
677
+ print('updated learning rate to %.6f' % lr)
678
+ patient = 0
679
+ print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
680
+ print('\n')
681
+ for split in datasets.keys():
682
+ if args.eval_with_CI and split not in ['train', 'extra_train', 'extra_dev']:
683
+ print('Currently going on ... ',split)
684
+ eval_dict = constrained_evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
685
+ constrained_write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
686
+ else:
687
+ eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
688
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
689
+
690
+ else:
691
+ logger.info("Evaluating")
692
+ epoch = start_epoch
693
+ for split in ['train','dev','test']:
694
+ if args.eval_with_CI and split != 'train':
695
+ eval_dict = constrained_evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
696
+ constrained_write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
697
+ else:
698
+ eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
699
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
700
+
701
+
702
+ if __name__ == '__main__':
703
+ main()
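
A note on the training loop in GraphParser.py above: when the dev score has not improved for `schedule` consecutive epochs, the parser rebuilds its optimizer with a decayed learning rate (lr = learning_rate / (1.0 + epoch * decay_rate)) and resets the patience counter. A minimal sketch of that behaviour is below; the numeric settings and the improvement check are illustrative assumptions, not values from this upload.

    # Sketch of the patience-based LR decay used in GraphParser.py (illustrative only).
    def decayed_lr(base_lr, epoch, decay_rate):
        # Same formula as the training loop: lr = learning_rate / (1.0 + epoch * decay_rate)
        return base_lr / (1.0 + epoch * decay_rate)

    base_lr, decay_rate, schedule = 0.002, 0.05, 5   # hypothetical settings
    patient = 0
    for epoch in range(1, 21):
        improved = (epoch % 3 == 0)                  # stand-in for the dev-set comparison
        patient = 0 if improved else patient + 1
        if patient >= schedule:
            lr = decayed_lr(base_lr, epoch, decay_rate)
            print('epoch %d: updated learning rate to %.6f' % (epoch, lr))
            patient = 0
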
examples/GraphParser_MRL.py ADDED
@@ -0,0 +1,603 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ from os import path, makedirs
4
+
5
+ sys.path.append(".")
6
+ sys.path.append("..")
7
+
8
+ import argparse
9
+ from copy import deepcopy
10
+ import json
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from collections import namedtuple
15
+ from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
16
+ from utils.models.parsing_gating import BiAffine_Parser_Gated
17
+ from utils import load_word_embeddings
18
+ from utils.tasks import parse
19
+ import time
20
+ from torch.nn.utils import clip_grad_norm_
21
+ from torch.optim import Adam, SGD
22
+ import uuid
23
+
24
+ uid = uuid.uuid4().hex[:6]
25
+
26
+ logger = get_logger('GraphParser')
27
+
28
+ def read_arguments():
29
+ args_ = argparse.ArgumentParser(description='Solving GraphParser')
30
+ args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
31
+ args_.add_argument('--domain', help='domain/language', required=True)
32
+ args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
33
+ required=True)
34
+ args_.add_argument('--gating',action='store_true', help='use gated mechanism')
35
+ args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
36
+ args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
37
+ args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
38
+ args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
39
+ args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
40
+ args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
41
+ args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
42
+ args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
43
+ args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
44
+ args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
45
+ args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
46
+ args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
47
+ args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
48
+ args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
49
+ args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
50
+ args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
51
+ args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
52
+ args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
53
+ args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
54
+ args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
55
+ args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
56
+ args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
57
+ args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
58
+ args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
59
+ args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
60
+ args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
61
+ args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
62
+ args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
63
+ args_.add_argument('--unk_replace', type=float, default=0.,
64
+ help='The rate to replace a singleton word with UNK')
65
+ args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
66
+ args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
67
+ help='Embedding for words')
68
+ args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
69
+ args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embedding (disable fine-tuning).')
70
+ args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
71
+ args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
72
+ required=True)
73
+ args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
74
+ required=True)
75
+ args_.add_argument('--char_path', help='path for character embedding dict')
76
+ args_.add_argument('--pos_path', help='path for pos embedding dict')
77
+ args_.add_argument('--set_num_training_samples', type=int, help='downsample the training set to a fixed number of samples')
78
+ args_.add_argument('--model_path', help='path for saving model file.', required=True)
79
+ args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
80
+ args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
81
+ args_.add_argument('--strict', action='store_true', help='if True loaded model state should contain '
82
+ 'exactly the same keys as current model')
83
+ args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
84
+ args = args_.parse_args()
85
+ args_dict = {}
86
+ args_dict['dataset'] = args.dataset
87
+ args_dict['domain'] = args.domain
88
+ args_dict['rnn_mode'] = args.rnn_mode
89
+ args_dict['gating'] = args.gating
90
+ args_dict['num_gates'] = args.num_gates
91
+ args_dict['arc_decode'] = args.arc_decode
92
+ # args_dict['splits'] = ['train', 'dev', 'test']
93
+ args_dict['splits'] = ['train', 'dev', 'test','poetry','prose']
94
+ args_dict['model_path'] = args.model_path
95
+ if not path.exists(args_dict['model_path']):
96
+ makedirs(args_dict['model_path'])
97
+ args_dict['data_paths'] = {}
98
+ if args_dict['dataset'] == 'ontonotes':
99
+ data_path = 'data/Pre_MRL/onto_pos_ner_dp'
100
+ else:
101
+ data_path = 'data/Prep_MRL/ud_pos_ner_dp'
102
+ for split in args_dict['splits']:
103
+ args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
104
+ ###################################
105
+ args_dict['data_paths']['poetry'] = data_path + '_' + 'test' + '_' + args_dict['domain']
106
+ args_dict['data_paths']['prose'] = data_path + '_' + 'test' + '_' + args_dict['domain']
107
+ ###################################
108
+ args_dict['alphabet_data_paths'] = {}
109
+ for split in args_dict['splits']:
110
+ if args_dict['dataset'] == 'ontonotes':
111
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
112
+ else:
113
+ if '_' in args_dict['domain']:
114
+ args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
115
+ else:
116
+ args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
117
+ args_dict['model_name'] = 'domain_' + args_dict['domain']
118
+ args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
119
+ args_dict['load_path'] = args.load_path
120
+ args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
121
+ if args_dict['load_sequence_taggers_paths'] is not None:
122
+ args_dict['gating'] = True
123
+ args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
124
+ else:
125
+ if not args_dict['gating']:
126
+ args_dict['num_gates'] = 0
127
+ args_dict['strict'] = args.strict
128
+ args_dict['num_epochs'] = args.num_epochs
129
+ args_dict['batch_size'] = args.batch_size
130
+ args_dict['hidden_size'] = args.hidden_size
131
+ args_dict['arc_space'] = args.arc_space
132
+ args_dict['arc_tag_space'] = args.arc_tag_space
133
+ args_dict['num_layers'] = args.num_layers
134
+ args_dict['num_filters'] = args.num_filters
135
+ args_dict['kernel_size'] = args.kernel_size
136
+ args_dict['learning_rate'] = args.learning_rate
137
+ args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
138
+ args_dict['opt'] = args.opt
139
+ args_dict['momentum'] = args.momentum
140
+ args_dict['betas'] = tuple(args.betas)
141
+ args_dict['epsilon'] = args.epsilon
142
+ args_dict['decay_rate'] = args.decay_rate
143
+ args_dict['clip'] = args.clip
144
+ args_dict['gamma'] = args.gamma
145
+ args_dict['schedule'] = args.schedule
146
+ args_dict['p_rnn'] = tuple(args.p_rnn)
147
+ args_dict['p_in'] = args.p_in
148
+ args_dict['p_out'] = args.p_out
149
+ args_dict['unk_replace'] = args.unk_replace
150
+ args_dict['set_num_training_samples'] = args.set_num_training_samples
151
+ args_dict['punct_set'] = None
152
+ if args.punct_set is not None:
153
+ args_dict['punct_set'] = set(args.punct_set)
154
+ logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
155
+ args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
156
+ args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
157
+ args_dict['word_embedding'] = args.word_embedding
158
+ args_dict['word_path'] = args.word_path
159
+ args_dict['use_char'] = args.use_char
160
+ args_dict['char_embedding'] = args.char_embedding
161
+ args_dict['char_path'] = args.char_path
162
+ args_dict['pos_embedding'] = args.pos_embedding
163
+ args_dict['pos_path'] = args.pos_path
164
+ args_dict['use_pos'] = args.use_pos
165
+ args_dict['pos_dim'] = args.pos_dim
166
+ args_dict['word_dict'] = None
167
+ args_dict['word_dim'] = args.word_dim
168
+ if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
169
+ args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
170
+ args_dict['word_path'])
171
+ args_dict['char_dict'] = None
172
+ args_dict['char_dim'] = args.char_dim
173
+ if args_dict['char_embedding'] != 'random':
174
+ args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
175
+ args_dict['char_path'])
176
+ args_dict['pos_dict'] = None
177
+ if args_dict['pos_embedding'] != 'random':
178
+ args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
179
+ args_dict['pos_path'])
180
+ args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
181
+ args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
182
+ args_dict['eval_mode'] = args.eval_mode
183
+ args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
184
+ args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
185
+ args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
186
+ args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
187
+ logger.info("Saving arguments to file")
188
+ save_args(args, args_dict['full_model_name'])
189
+ logger.info("Creating Alphabets")
190
+ alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
191
+ args_dict = {**args_dict, **alphabet_dict}
192
+ ARGS = namedtuple('ARGS', args_dict.keys())
193
+ my_args = ARGS(**args_dict)
194
+ return my_args
195
+
196
+
197
+ def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
198
+ train_paths = alphabet_data_paths['train']
199
+ extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
200
+ alphabet_dict = {}
201
+ alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
202
+ train_paths,
203
+ extra_paths=extra_paths,
204
+ max_vocabulary_size=100000,
205
+ embedd_dict=word_dict)
206
+ for k, v in alphabet_dict['alphabets'].items():
207
+ num_key = 'num_' + k.split('_')[0]
208
+ alphabet_dict[num_key] = v.size()
209
+ logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
210
+ return alphabet_dict
211
+
212
+ def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
213
+ if tokens_dict is None:
214
+ return None
215
+ scale = np.sqrt(3.0 / dim)
216
+ table = np.empty([alphabet.size(), dim], dtype=np.float32)
217
+ table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
218
+ oov_tokens = 0
219
+ for token, index in alphabet.items():
220
+ if token in ['aTA']:
221
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
222
+ oov_tokens += 1
223
+ elif token in tokens_dict:
224
+ embedding = tokens_dict[token]
225
+ elif token.lower() in tokens_dict:
226
+ embedding = tokens_dict[token.lower()]
227
+ else:
228
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
229
+ oov_tokens += 1
230
+ # print(token)
231
+ table[index, :] = embedding
232
+ print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
233
+ table = torch.from_numpy(table)
234
+ return table
235
+
236
+ def save_args(args, full_model_name):
237
+ arg_path = full_model_name + '.arg.json'
238
+ argparse_dict = vars(args)
239
+ with open(arg_path, 'w') as f:
240
+ json.dump(argparse_dict, f)
241
+
242
+ def generate_optimizer(args, lr, params):
243
+ params = filter(lambda param: param.requires_grad, params)
244
+ if args.opt == 'adam':
245
+ return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
246
+ elif args.opt == 'sgd':
247
+ return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
248
+ else:
249
+ raise ValueError('Unknown optimization algorithm: %s' % args.opt)
250
+
251
+
252
+ def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
253
+ path_name = full_model_name + '.pt'
254
+ print('Saving model to: %s' % path_name)
255
+ state = {'model_state_dict': model.state_dict(),
256
+ 'optimizer_state_dict': optimizer.state_dict(),
257
+ 'opt': opt,
258
+ 'dev_eval_dict': dev_eval_dict,
259
+ 'test_eval_dict': test_eval_dict}
260
+ torch.save(state, path_name)
261
+
262
+
263
+ def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
264
+ print('Loading saved model from: %s' % load_path)
265
+ checkpoint = torch.load(load_path, map_location=args.device)
266
+ if checkpoint['opt'] != args.opt:
267
+ raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
268
+ model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
269
+
270
+ if strict:
271
+ generate_optimizer(args, args.learning_rate, model.parameters())
272
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
273
+ for state in optimizer.state.values():
274
+ for k, v in state.items():
275
+ if isinstance(v, torch.Tensor):
276
+ state[k] = v.to(args.device)
277
+ dev_eval_dict = checkpoint['dev_eval_dict']
278
+ test_eval_dict = checkpoint['test_eval_dict']
279
+ start_epoch = dev_eval_dict['in_domain']['epoch']
280
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
281
+
282
+
283
+ def build_model_and_optimizer(args):
284
+ word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
285
+ char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
286
+ pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
287
+ model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
288
+ args.use_pos, args.use_char, args.pos_dim, args.num_pos,
289
+ args.num_filters, args.kernel_size, args.rnn_mode,
290
+ args.hidden_size, args.num_layers, args.num_arc,
291
+ args.arc_space, args.arc_tag_space, args.num_gates,
292
+ embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
293
+ p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
294
+ biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
295
+ print(model)
296
+ optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
297
+ start_epoch = 0
298
+ dev_eval_dict = {'in_domain': initialize_eval_dict()}
299
+ test_eval_dict = {'in_domain': initialize_eval_dict()}
300
+ if args.load_path:
301
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
302
+ load_checkpoint(args, model, optimizer,
303
+ dev_eval_dict, test_eval_dict,
304
+ start_epoch, args.load_path, strict=args.strict)
305
+ if args.load_sequence_taggers_paths:
306
+ pretrained_dict = {}
307
+ model_dict = model.state_dict()
308
+ for idx, path in enumerate(args.load_sequence_taggers_paths):
309
+ print('Loading saved sequence_tagger from: %s' % path)
310
+ checkpoint = torch.load(path, map_location=args.device)
311
+ for k, v in checkpoint['model_state_dict'].items():
312
+ if 'rnn_encoder.' in k:
313
+ pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
314
+ model_dict.update(pretrained_dict)
315
+ model.load_state_dict(model_dict)
316
+ if args.freeze_sequence_taggers:
317
+ print('Freezing Classifiers')
318
+ for name, parameter in model.named_parameters():
319
+ if 'extra_rnn_encoders' in name:
320
+ parameter.requires_grad = False
321
+ if args.freeze_word_embeddings:
322
+ model.rnn_encoder.word_embedd.weight.requires_grad = False
323
+ # model.rnn_encoder.char_embedd.weight.requires_grad = False
324
+ # model.rnn_encoder.pos_embedd.weight.requires_grad = False
325
+ device = args.device
326
+ model.to(device)
327
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
328
+
329
+
330
+ def initialize_eval_dict():
331
+ eval_dict = {}
332
+ eval_dict['dp_uas'] = 0.0
333
+ eval_dict['dp_las'] = 0.0
334
+ eval_dict['epoch'] = 0
335
+ eval_dict['dp_ucorrect'] = 0.0
336
+ eval_dict['dp_lcorrect'] = 0.0
337
+ eval_dict['dp_total'] = 0.0
338
+ eval_dict['dp_ucomplete_match'] = 0.0
339
+ eval_dict['dp_lcomplete_match'] = 0.0
340
+ eval_dict['dp_ucorrect_nopunc'] = 0.0
341
+ eval_dict['dp_lcorrect_nopunc'] = 0.0
342
+ eval_dict['dp_total_nopunc'] = 0.0
343
+ eval_dict['dp_ucomplete_match_nopunc'] = 0.0
344
+ eval_dict['dp_lcomplete_match_nopunc'] = 0.0
345
+ eval_dict['dp_root_correct'] = 0.0
346
+ eval_dict['dp_total_root'] = 0.0
347
+ eval_dict['dp_total_inst'] = 0.0
348
+ eval_dict['dp_total'] = 0.0
349
+ eval_dict['dp_total_inst'] = 0.0
350
+ eval_dict['dp_total_nopunc'] = 0.0
351
+ eval_dict['dp_total_root'] = 0.0
352
+ return eval_dict
353
+
354
+ def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
355
+ best_model, best_optimizer, patient):
356
+ # In-domain evaluation
357
+ curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
358
+ is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
359
+ (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
360
+ dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
361
+
362
+ if is_best_in_domain:
363
+ for key, value in curr_dev_eval_dict.items():
364
+ dev_eval_dict['in_domain'][key] = value
365
+ curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
366
+ for key, value in curr_test_eval_dict.items():
367
+ test_eval_dict['in_domain'][key] = value
368
+ best_model = deepcopy(model)
369
+ best_optimizer = deepcopy(optimizer)
370
+ patient = 0
371
+ else:
372
+ patient += 1
373
+ if epoch == args.num_epochs:
374
+ # save in-domain checkpoint
375
+ if args.set_num_training_samples is not None:
376
+ splits_to_write = datasets.keys()
377
+ else:
378
+ splits_to_write = ['dev', 'test']
379
+ for split in splits_to_write:
380
+ if split == 'dev':
381
+ eval_dict = dev_eval_dict['in_domain']
382
+ elif split == 'test':
383
+ eval_dict = test_eval_dict['in_domain']
384
+ else:
385
+ eval_dict = None
386
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
387
+ print("Saving best model")
388
+ save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
389
+
390
+ print('\n')
391
+ return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
392
+
393
+
394
+ def evaluation(args, data, split, model, domain, epoch, str_res='results'):
395
+ # evaluate performance on data
396
+ model.eval()
397
+
398
+ eval_dict = initialize_eval_dict()
399
+ eval_dict['epoch'] = epoch
400
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
401
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
402
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
403
+ heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos, ner,out_arc, out_arc_tag, mask=masks, length=lengths,
404
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
405
+ lengths = lengths.cpu().numpy()
406
+ word = word.data.cpu().numpy()
407
+ pos = pos.data.cpu().numpy()
408
+ ner = ner.data.cpu().numpy()
409
+ heads = heads.data.cpu().numpy()
410
+ arc_tags = arc_tags.data.cpu().numpy()
411
+ heads_pred = heads_pred.data.cpu().numpy()
412
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
413
+ stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
414
+ arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
415
+ lengths, punct_set=args.punct_set, symbolic_root=True)
416
+ ucorr, lcorr, total, ucm, lcm = stats
417
+ ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
418
+ corr_root, total_root = stats_root
419
+ eval_dict['dp_ucorrect'] += ucorr
420
+ eval_dict['dp_lcorrect'] += lcorr
421
+ eval_dict['dp_total'] += total
422
+ eval_dict['dp_ucomplete_match'] += ucm
423
+ eval_dict['dp_lcomplete_match'] += lcm
424
+ eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
425
+ eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
426
+ eval_dict['dp_total_nopunc'] += total_nopunc
427
+ eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
428
+ eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
429
+ eval_dict['dp_root_correct'] += corr_root
430
+ eval_dict['dp_total_root'] += total_root
431
+ eval_dict['dp_total_inst'] += num_inst
432
+
433
+ eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
434
+ eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
435
+ print_results(eval_dict, split, domain, str_res)
436
+ return eval_dict
437
+
438
+
439
+ def print_results(eval_dict, split, domain, str_res='results'):
440
+ print('----------------------------------------------------------------------------------------------------------------------------')
441
+ print('Testing model on domain %s' % domain)
442
+ print('--------------- Dependency Parsing - %s ---------------' % split)
443
+ print(
444
+ str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
445
+ eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
446
+ eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
447
+ eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
448
+ eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
449
+ eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
450
+ eval_dict['epoch']))
451
+ print(
452
+ str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
453
+ eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
454
+ eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
455
+ eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
456
+ eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
457
+ eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
458
+ eval_dict['epoch']))
459
+ print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
460
+ eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
461
+ eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
462
+ print('\n')
463
+
464
+ def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
465
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
466
+ res_filename = str_file + '_res.txt'
467
+ pred_filename = str_file + '_pred.txt'
468
+ gold_filename = str_file + '_gold.txt'
469
+ if eval_dict is not None:
470
+ # save results dictionary into a file
471
+ with open(res_filename, 'w') as f:
472
+ json.dump(eval_dict, f)
473
+
474
+ # save predictions and gold labels into files
475
+ pred_writer = Writer(args.alphabets)
476
+ gold_writer = Writer(args.alphabets)
477
+ pred_writer.start(pred_filename)
478
+ gold_writer.start(gold_filename)
479
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
480
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
481
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
482
+ heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos,ner,out_arc, out_arc_tag, mask=masks, length=lengths,
483
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
484
+ lengths = lengths.cpu().numpy()
485
+ word = word.data.cpu().numpy()
486
+ pos = pos.data.cpu().numpy()
487
+ ner = ner.data.cpu().numpy()
488
+ heads = heads.data.cpu().numpy()
489
+ arc_tags = arc_tags.data.cpu().numpy()
490
+ heads_pred = heads_pred.data.cpu().numpy()
491
+ arc_tags_pred = arc_tags_pred.data.cpu().numpy()
492
+ # writing predictions
493
+ pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
494
+ # writing gold labels
495
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
496
+
497
+ pred_writer.close()
498
+ gold_writer.close()
499
+
500
+ def main():
501
+ logger.info("Reading and creating arguments")
502
+ args = read_arguments()
503
+ logger.info("Reading Data")
504
+ datasets = {}
505
+ for split in args.splits:
506
+ print("Reading split:", split)
507
+ dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
508
+ symbolic_root=True)
509
+ datasets[split] = dataset
510
+ if args.set_num_training_samples is not None:
511
+ print('Setting train and dev to %d samples' % args.set_num_training_samples)
512
+ datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
513
+ logger.info("Creating Networks")
514
+ num_data = sum(datasets['train'][1])
515
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
516
+ best_model = deepcopy(model)
517
+ best_optimizer = deepcopy(optimizer)
518
+
519
+ logger.info('Training INFO of in domain %s' % args.domain)
520
+ logger.info('Training on Dependency Parsing')
521
+ logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
522
+ logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
523
+ logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
524
+ logger.info("num_epochs: %d" % (args.num_epochs))
525
+ print('\n')
526
+
527
+ if not args.eval_mode:
528
+ logger.info("Training")
529
+ num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
530
+ lr = args.learning_rate
531
+ patient = 0
532
+ decay = 0
533
+ for epoch in range(start_epoch + 1, args.num_epochs + 1):
534
+ print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
535
+ epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
536
+ model.train()
537
+ total_loss = 0.0
538
+ total_arc_loss = 0.0
539
+ total_arc_tag_loss = 0.0
540
+ total_train_inst = 0.0
541
+
542
+ train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
543
+ datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
544
+ start_time = time.time()
545
+ batch_num = 0
546
+ for batch_num, batch in enumerate(train_iter):
547
+ batch_num = batch_num + 1
548
+ optimizer.zero_grad()
549
+ # compute loss of main task
550
+ word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
551
+ out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
552
+ loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
553
+ loss = loss_arc + loss_arc_tag
554
+
555
+ # update losses
556
+ num_insts = masks.data.sum() - word.size(0)
557
+ total_arc_loss += loss_arc.item() * num_insts
558
+ total_arc_tag_loss += loss_arc_tag.item() * num_insts
559
+ total_loss += loss.item() * num_insts
560
+ total_train_inst += num_insts
561
+ # optimize parameters
562
+ loss.backward()
563
+ clip_grad_norm_(model.parameters(), args.clip)
564
+ optimizer.step()
565
+
566
+ time_ave = (time.time() - start_time) / batch_num
567
+ time_left = (num_batches - batch_num) * time_ave
568
+
569
+ # update log
570
+ if batch_num % 50 == 0:
571
+ log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
572
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
573
+ total_arc_tag_loss / total_train_inst, time_left)
574
+ sys.stdout.write(log_info)
575
+ sys.stdout.write('\n')
576
+ sys.stdout.flush()
577
+ print('\n')
578
+ print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
579
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
580
+ total_arc_tag_loss / total_train_inst, time.time() - start_time))
581
+
582
+ dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
583
+ if patient >= args.schedule:
584
+ lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
585
+ optimizer = generate_optimizer(args, lr, model.parameters())
586
+ print('updated learning rate to %.6f' % lr)
587
+ patient = 0
588
+ print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
589
+ print('\n')
590
+ for split in datasets.keys():
591
+ eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
592
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
593
+
594
+ else:
595
+ logger.info("Evaluating")
596
+ epoch = start_epoch
597
+ for split in ['train', 'dev', 'test','poetry','prose']:
598
+ eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
599
+ write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
600
+
601
+
602
+ if __name__ == '__main__':
603
+ main()
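
A brief note on how GraphParser_MRL.py wires in pre-trained sequence taggers: for every checkpoint passed via --load_sequence_taggers_paths, the parameters stored under 'rnn_encoder.' are copied into the gated parser's state dict under 'extra_rnn_encoders.<idx>.'. The following is a minimal, self-contained sketch of that remapping; the function name, model object and checkpoint paths are placeholders, not part of the upload.

    # Sketch of the encoder-weight remapping performed in build_model_and_optimizer().
    import torch

    def remap_tagger_encoders(model, tagger_checkpoint_paths, device='cpu'):
        model_dict = model.state_dict()
        pretrained = {}
        for idx, ckpt_path in enumerate(tagger_checkpoint_paths):   # hypothetical paths
            checkpoint = torch.load(ckpt_path, map_location=device)
            for key, value in checkpoint['model_state_dict'].items():
                if 'rnn_encoder.' in key:
                    # rnn_encoder.<param> -> extra_rnn_encoders.<idx>.<param>
                    new_key = 'extra_rnn_encoders.%d.%s' % (idx, key.replace('rnn_encoder.', ''))
                    pretrained[new_key] = value
        model_dict.update(pretrained)          # overwrite matching keys only
        model.load_state_dict(model_dict)
        return model
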
examples/SequenceTagger.py ADDED
@@ -0,0 +1,597 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ from os import path, makedirs, system, remove
4
+
5
+ sys.path.append(".")
6
+ sys.path.append("..")
7
+
8
+ import time
9
+ import argparse
10
+ import uuid
11
+ import json
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from collections import namedtuple
16
+ from copy import deepcopy
17
+ from torch.nn.utils import clip_grad_norm_
18
+ from torch.optim import Adam, SGD
19
+ from utils.io_ import seeds, Writer, get_logger, Index2Instance, prepare_data, write_extra_labels
20
+ from utils.models.sequence_tagger import Sequence_Tagger
21
+ from utils import load_word_embeddings
22
+ from utils.tasks.seqeval import accuracy_score, f1_score, precision_score, recall_score,classification_report
23
+
24
+ uid = uuid.uuid4().hex[:6]
25
+
26
+ logger = get_logger('SequenceTagger')
27
+
28
+ def read_arguments():
29
+ args_ = argparse.ArgumentParser(description='Solving SequenceTagger')
30
+ args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
31
+ args_.add_argument('--domain', help='domain', required=True)
32
+ args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
33
+ required=True)
34
+ args_.add_argument('--task', default='distance_from_the_root', choices=['distance_from_the_root', 'number_of_children',\
35
+ 'relative_pos_based', 'language_model','add_label','add_head_coarse_pos','Multitask_POS_predict','Multitask_case_predict',\
36
+ 'Multitask_label_predict','Multitask_coarse_predict','MRL_Person','MRL_Gender','MRL_case','MRL_POS','MRL_no','MRL_label',\
37
+ 'predict_coarse_of_modifier','predict_ma_tag_of_modifier','add_head_ma','predict_case_of_modifier'], help='sequence_tagger task')
38
+ args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
39
+ args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
40
+ args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
41
+ args_.add_argument('--tag_space', type=int, default=128, help='Dimension of tag space')
42
+ args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
43
+ args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
44
+ args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
45
+ args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
46
+ args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
47
+ args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
48
+ args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
49
+ args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
50
+ args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
51
+ args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
52
+ args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
53
+ args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
54
+ args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
55
+ args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
56
+ args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
57
+ args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
58
+ args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
59
+ args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
60
+ args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
61
+ args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
62
+ args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
63
+ args_.add_argument('--unk_replace', type=float, default=0.,
64
+ help='The rate to replace a singleton word with UNK')
65
+ args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
66
+ args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
67
+ help='Embedding for words')
68
+ args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
69
+ args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embedding (disable fine-tuning).')
70
+ args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
71
+ required=True)
72
+ args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
73
+ required=True)
74
+ args_.add_argument('--char_path', help='path for character embedding dict')
75
+ args_.add_argument('--pos_path', help='path for pos embedding dict')
76
+ args_.add_argument('--use_unlabeled_data', action='store_true', help='flag to use unlabeled data.')
77
+ args_.add_argument('--use_labeled_data', action='store_true', help='flag to use labeled data.')
78
+ args_.add_argument('--model_path', help='path for saving model file.', required=True)
79
+ args_.add_argument('--parser_path', help='path for loading parser files.', default=None)
80
+ args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
81
+ args_.add_argument('--strict', action='store_true', help='if True loaded model state should contain '
82
+ 'exactly the same keys as current model')
83
+ args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
84
+ args = args_.parse_args()
85
+ args_dict = {}
86
+ args_dict['dataset'] = args.dataset
87
+ args_dict['domain'] = args.domain
88
+ args_dict['task'] = args.task
89
+ args_dict['rnn_mode'] = args.rnn_mode
90
+ args_dict['load_path'] = args.load_path
91
+ args_dict['strict'] = args.strict
92
+ args_dict['model_path'] = args.model_path
93
+ if not path.exists(args_dict['model_path']):
94
+ makedirs(args_dict['model_path'])
95
+ args_dict['parser_path'] = args.parser_path
96
+ args_dict['model_name'] = 'domain_' + args_dict['domain']
97
+ args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
98
+ args_dict['use_unlabeled_data'] = args.use_unlabeled_data
99
+ args_dict['use_labeled_data'] = args.use_labeled_data
100
+ print(args_dict['parser_path'])
101
+ if args_dict['task'] == 'number_of_children':
102
+ args_dict['data_paths'] = write_extra_labels.add_number_of_children(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
103
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
104
+ use_labeled_data=args_dict['use_labeled_data'])
105
+ elif args_dict['task'] == 'distance_from_the_root':
106
+ args_dict['data_paths'] = write_extra_labels.add_distance_from_the_root(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
107
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
108
+ use_labeled_data=args_dict['use_labeled_data'])
109
+ elif args_dict['task'] == 'Multitask_label_predict':
110
+ args_dict['data_paths'] = write_extra_labels.Multitask_label_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
111
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
112
+ use_labeled_data=args_dict['use_labeled_data'])
113
+ elif args_dict['task'] == 'Multitask_coarse_predict':
114
+ args_dict['data_paths'] = write_extra_labels.Multitask_coarse_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
115
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
116
+ use_labeled_data=args_dict['use_labeled_data'])
117
+ elif args_dict['task'] == 'Multitask_POS_predict':
118
+ args_dict['data_paths'] = write_extra_labels.Multitask_POS_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
119
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
120
+ use_labeled_data=args_dict['use_labeled_data'])
121
+ elif args_dict['task'] == 'relative_pos_based':
122
+ args_dict['data_paths'] = write_extra_labels.add_relative_pos_based(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
123
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
124
+ use_labeled_data=args_dict['use_labeled_data'])
125
+ elif args_dict['task'] == 'add_label':
126
+ args_dict['data_paths'] = write_extra_labels.add_label(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
127
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
128
+ use_labeled_data=args_dict['use_labeled_data'])
129
+ elif args_dict['task'] == 'add_relative_TAG':
130
+ args_dict['data_paths'] = write_extra_labels.add_relative_TAG(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
131
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
132
+ use_labeled_data=args_dict['use_labeled_data'])
133
+ elif args_dict['task'] == 'add_head_coarse_pos':
134
+ args_dict['data_paths'] = write_extra_labels.add_head_coarse_pos(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
135
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
136
+ use_labeled_data=args_dict['use_labeled_data'])
137
+ elif args_dict['task'] == 'predict_ma_tag_of_modifier':
138
+ args_dict['data_paths'] = write_extra_labels.predict_ma_tag_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
139
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
140
+ use_labeled_data=args_dict['use_labeled_data'])
141
+ elif args_dict['task'] == 'Multitask_case_predict':
142
+ args_dict['data_paths'] = write_extra_labels.Multitask_case_predict(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
143
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
144
+ use_labeled_data=args_dict['use_labeled_data'])
145
+ elif args_dict['task'] == 'predict_coarse_of_modifier':
146
+ args_dict['data_paths'] = write_extra_labels.predict_coarse_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
147
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
148
+ use_labeled_data=args_dict['use_labeled_data'])
149
+ elif args_dict['task'] == 'predict_case_of_modifier':
150
+ args_dict['data_paths'] = write_extra_labels.predict_case_of_modifier(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
151
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
152
+ use_labeled_data=args_dict['use_labeled_data'])
153
+ elif args_dict['task'] == 'add_head_ma':
154
+ args_dict['data_paths'] = write_extra_labels.add_head_ma(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
155
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
156
+ use_labeled_data=args_dict['use_labeled_data'])
157
+ elif args_dict['task'] == 'MRL_case':
158
+ args_dict['data_paths'] = write_extra_labels.MRL_case(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
159
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
160
+ use_labeled_data=args_dict['use_labeled_data'])
161
+ elif args_dict['task'] == 'MRL_POS':
162
+ args_dict['data_paths'] = write_extra_labels.MRL_POS(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
163
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
164
+ use_labeled_data=args_dict['use_labeled_data'])
165
+ elif args_dict['task'] == 'MRL_no':
166
+ args_dict['data_paths'] = write_extra_labels.MRL_no(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
167
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
168
+ use_labeled_data=args_dict['use_labeled_data'])
169
+ elif args_dict['task'] == 'MRL_label':
170
+ args_dict['data_paths'] = write_extra_labels.MRL_label(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
171
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
172
+ use_labeled_data=args_dict['use_labeled_data'])
173
+ elif args_dict['task'] == 'MRL_Person':
174
+ args_dict['data_paths'] = write_extra_labels.MRL_Person(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
175
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
176
+ use_labeled_data=args_dict['use_labeled_data'])
177
+ elif args_dict['task'] == 'MRL_Gender':
178
+ args_dict['data_paths'] = write_extra_labels.MRL_Gender(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
179
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
180
+ use_labeled_data=args_dict['use_labeled_data'])
181
+ else: #args_dict['task'] == 'language_model':
182
+ args_dict['data_paths'] = write_extra_labels.add_language_model(args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
183
+ use_unlabeled_data=args_dict['use_unlabeled_data'],
184
+ use_labeled_data=args_dict['use_labeled_data'])
185
+ args_dict['splits'] = args_dict['data_paths'].keys()
186
+ alphabet_data_paths = deepcopy(args_dict['data_paths'])
187
+ if args_dict['dataset'] == 'ontonotes':
188
+ data_path = 'data/onto_pos_ner_dp'
189
+ else:
190
+ data_path = 'data/ud_pos_ner_dp'
191
+ # Adding more resources to make sure equal alphabet size for all domains
192
+ for split in args_dict['splits']:
193
+ if args_dict['dataset'] == 'ontonotes':
194
+ alphabet_data_paths['additional_' + split] = data_path + '_' + split + '_' + 'all'
195
+ else:
196
+ if '_' in args_dict['domain']:
197
+ alphabet_data_paths[split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
198
+ else:
199
+ alphabet_data_paths[split] = args_dict['data_paths'][split]
200
+ args_dict['alphabet_data_paths'] = alphabet_data_paths
201
+ args_dict['num_epochs'] = args.num_epochs
202
+ args_dict['batch_size'] = args.batch_size
203
+ args_dict['hidden_size'] = args.hidden_size
204
+ args_dict['tag_space'] = args.tag_space
205
+ args_dict['num_layers'] = args.num_layers
206
+ args_dict['num_filters'] = args.num_filters
207
+ args_dict['kernel_size'] = args.kernel_size
208
+ args_dict['learning_rate'] = args.learning_rate
209
+ args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
210
+ args_dict['opt'] = args.opt
211
+ args_dict['momentum'] = args.momentum
212
+ args_dict['betas'] = tuple(args.betas)
213
+ args_dict['epsilon'] = args.epsilon
214
+ args_dict['decay_rate'] = args.decay_rate
215
+ args_dict['clip'] = args.clip
216
+ args_dict['gamma'] = args.gamma
217
+ args_dict['schedule'] = args.schedule
218
+ args_dict['p_rnn'] = tuple(args.p_rnn)
219
+ args_dict['p_in'] = args.p_in
220
+ args_dict['p_out'] = args.p_out
221
+ args_dict['unk_replace'] = args.unk_replace
222
+ args_dict['punct_set'] = None
223
+ if args.punct_set is not None:
224
+ args_dict['punct_set'] = set(args.punct_set)
225
+ logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
226
+ args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
227
+ args_dict['word_embedding'] = args.word_embedding
228
+ args_dict['word_path'] = args.word_path
229
+ args_dict['use_char'] = args.use_char
230
+ args_dict['char_embedding'] = args.char_embedding
231
+ args_dict['pos_embedding'] = args.pos_embedding
232
+ args_dict['char_path'] = args.char_path
233
+ args_dict['pos_path'] = args.pos_path
234
+ args_dict['use_pos'] = args.use_pos
235
+ args_dict['pos_dim'] = args.pos_dim
236
+ args_dict['word_dict'] = None
237
+ args_dict['word_dim'] = args.word_dim
238
+ if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
239
+ args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
240
+ args_dict['word_path'])
241
+ args_dict['char_dict'] = None
242
+ args_dict['char_dim'] = args.char_dim
243
+ if args_dict['char_embedding'] != 'random':
244
+ args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
245
+ args_dict['char_path'])
246
+ args_dict['pos_dict'] = None
247
+ if args_dict['pos_embedding'] != 'random':
248
+ args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
249
+ args_dict['pos_path'])
250
+ args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
251
+ args_dict['alphabet_parser_path'] = path.join(args_dict['parser_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
252
+ args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
253
+ args_dict['eval_mode'] = args.eval_mode
254
+ args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
255
+ args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
256
+ args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
257
+ args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
258
+ logger.info("Saving arguments to file")
259
+ save_args(args, args_dict['full_model_name'])
260
+ logger.info("Creating Alphabets")
261
+ alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_parser_path'], args_dict['alphabet_data_paths'])
262
+ args_dict = {**args_dict, **alphabet_dict}
263
+ ARGS = namedtuple('ARGS', args_dict.keys())
264
+ my_args = ARGS(**args_dict)
265
+ return my_args
266
+
267
+
268
+ def creating_alphabets(alphabet_path, alphabet_parser_path, alphabet_data_paths):
269
+ data_paths_list = alphabet_data_paths.values()
270
+ alphabet_dict = {}
271
+ alphabet_dict['alphabets'] = prepare_data.create_alphabets_for_sequence_tagger(alphabet_path, alphabet_parser_path, data_paths_list)
272
+ for k, v in alphabet_dict['alphabets'].items():
273
+ num_key = 'num_' + k.split('_alphabet')[0]
274
+ alphabet_dict[num_key] = v.size()
275
+ logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
276
+ return alphabet_dict
277
+
278
+ def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
279
+ if tokens_dict is None:
280
+ return None
281
+ scale = np.sqrt(3.0 / dim)
282
+ table = np.empty([alphabet.size(), dim], dtype=np.float32)
283
+ table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
284
+ oov_tokens = 0
285
+ for token, index in alphabet.items():
286
+ if token in tokens_dict:
287
+ embedding = tokens_dict[token]
288
+ if token =='ata':
289
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
290
+ else:
291
+ embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
292
+ oov_tokens += 1
293
+ table[index, :] = embedding
294
+ print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
295
+ table = torch.from_numpy(table)
296
+ return table
297
+
298
+ def get_free_gpu():
299
+ system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp.txt')
300
+ memory_available = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()]
301
+ remove("tmp.txt")
302
+ free_device = 'cuda:' + str(np.argmax(memory_available))
303
+ return free_device
304
+
305
+ def save_args(args, full_model_name):
306
+ arg_path = full_model_name + '.arg.json'
307
+ argparse_dict = vars(args)
308
+ with open(arg_path, 'w') as f:
309
+ json.dump(argparse_dict, f)
310
+
311
+ def generate_optimizer(args, lr, params):
312
+ params = filter(lambda param: param.requires_grad, params)
313
+ if args.opt == 'adam':
314
+ return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
315
+ elif args.opt == 'sgd':
316
+ return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
317
+ else:
318
+ raise ValueError('Unknown optimization algorithm: %s' % args.opt)
319
+
320
+
321
+ def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
322
+ path_name = full_model_name + '.pt'
323
+ print('Saving model to: %s' % path_name)
324
+ state = {'model_state_dict': model.state_dict(),
325
+ 'optimizer_state_dict': optimizer.state_dict(),
326
+ 'opt': opt, 'dev_eval_dict': dev_eval_dict, 'test_eval_dict': test_eval_dict}
327
+ torch.save(state, path_name)
328
+
329
+
330
+ def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
331
+ print('Loading saved model from: %s' % load_path)
332
+ checkpoint = torch.load(load_path, map_location=args.device)
333
+ if checkpoint['opt'] != args.opt:
334
+ raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
335
+ model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
336
+ if strict:
337
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
338
+ dev_eval_dict = checkpoint['dev_eval_dict']
339
+ test_eval_dict = checkpoint['test_eval_dict']
340
+ start_epoch = dev_eval_dict['in_domain']['epoch']
341
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
342
+
343
+
344
+ def build_model_and_optimizer(args):
345
+ word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
346
+ char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
347
+ pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
348
+ model = Sequence_Tagger(args.word_dim, args.num_word, args.char_dim, args.num_char,
349
+ args.use_pos, args.use_char, args.pos_dim, args.num_pos,
350
+ args.num_filters, args.kernel_size, args.rnn_mode,
351
+ args.hidden_size, args.num_layers, args.tag_space, args.num_auto_label,
352
+ embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
353
+ p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
354
+ initializer=args.initializer)
355
+ optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
356
+ start_epoch = 0
357
+ dev_eval_dict = {'in_domain': initialize_eval_dict()}
358
+ test_eval_dict = {'in_domain': initialize_eval_dict()}
359
+ if args.load_path:
360
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
361
+ load_checkpoint(args, model, optimizer,
362
+ dev_eval_dict, test_eval_dict,
363
+ start_epoch, args.load_path, strict=args.strict)
364
+ if args.freeze_word_embeddings:
365
+ model.rnn_encoder.word_embedd.weight.requires_grad = False
366
+ # model.rnn_encoder.char_embedd.weight.requires_grad = False
367
+ # model.rnn_encoder.pos_embedd.weight.requires_grad = False
368
+ device = args.device
369
+ model.to(device)
370
+ return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
371
+
372
+
373
+ def initialize_eval_dict():
374
+ eval_dict = {}
375
+ eval_dict['auto_label_accuracy'] = 0.0
376
+ eval_dict['auto_label_precision'] = 0.0
377
+ eval_dict['auto_label_recall'] = 0.0
378
+ eval_dict['auto_label_f1'] = 0.0
379
+ return eval_dict
380
+
381
+ def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
382
+ best_model, best_optimizer, patient):
383
+ # In-domain evaluation
384
+ curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
385
+ is_best_in_domain = dev_eval_dict['in_domain']['auto_label_f1'] <= curr_dev_eval_dict['auto_label_f1']
386
+
387
+ if is_best_in_domain:
388
+ for key, value in curr_dev_eval_dict.items():
389
+ dev_eval_dict['in_domain'][key] = value
390
+ curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
391
+ for key, value in curr_test_eval_dict.items():
392
+ test_eval_dict['in_domain'][key] = value
393
+ best_model = deepcopy(model)
394
+ best_optimizer = deepcopy(optimizer)
395
+ patient = 0
396
+ else:
397
+ patient += 1
398
+ if epoch == args.num_epochs:
399
+ # save in-domain checkpoint
400
+ for split in ['dev', 'test']:
401
+ eval_dict = dev_eval_dict['in_domain'] if split == 'dev' else test_eval_dict['in_domain']
402
+ write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
403
+ save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
404
+
405
+ print('\n')
406
+ return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict
407
+
408
+
409
+ def evaluation(args, data, split, model, domain, epoch, str_res='results'):
410
+ # evaluate performance on data
411
+ model.eval()
412
+ auto_label_idx2inst = Index2Instance(args.alphabets['auto_label_alphabet'])
413
+ eval_dict = initialize_eval_dict()
414
+ eval_dict['epoch'] = epoch
415
+ pred_labels = []
416
+ gold_labels = []
417
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
418
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
419
+ output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
420
+ auto_label_preds = model.decode(output, mask=masks, length=lengths, leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
421
+ lengths = lengths.cpu().numpy()
422
+ word = word.data.cpu().numpy()
423
+ pos = pos.data.cpu().numpy()
424
+ ner = ner.data.cpu().numpy()
425
+ heads = heads.data.cpu().numpy()
426
+ arc_tags = arc_tags.data.cpu().numpy()
427
+ auto_label = auto_label.data.cpu().numpy()
428
+ auto_label_preds = auto_label_preds.data.cpu().numpy()
429
+ gold_labels += auto_label_idx2inst.index2instance(auto_label, lengths, symbolic_root=True)
430
+ pred_labels += auto_label_idx2inst.index2instance(auto_label_preds, lengths, symbolic_root=True)
431
+
432
+ eval_dict['auto_label_accuracy'] = accuracy_score(gold_labels, pred_labels) * 100
433
+ eval_dict['auto_label_precision'] = precision_score(gold_labels, pred_labels) * 100
434
+ eval_dict['auto_label_recall'] = recall_score(gold_labels, pred_labels) * 100
435
+ eval_dict['auto_label_f1'] = f1_score(gold_labels, pred_labels) * 100
436
+ eval_dict['classification_report'] = classification_report(gold_labels, pred_labels)
437
+ print_results(eval_dict, split, domain, str_res)
438
+ return eval_dict
439
+
440
+
441
+ def print_results(eval_dict, split, domain, str_res='results'):
442
+ print('----------------------------------------------------------------------------------------------------------------------------')
443
+ print('Testing model on domain %s' % domain)
444
+ print('--------------- sequence_tagger - %s ---------------' % split)
445
+ print(
446
+ str_res + ' on ' + split + ' accuracy: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)'
447
+ % (eval_dict['auto_label_accuracy'], eval_dict['auto_label_precision'], eval_dict['auto_label_recall'], eval_dict['auto_label_f1'],
448
+ eval_dict['epoch']))
449
+ print(eval_dict['classification_report'])
450
+
451
+
452
+ def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
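+ # write the evaluation dictionary to *_res.txt and per-token predictions / gold labels to *_pred.txt / *_gold.txt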
453
+ str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
454
+ res_filename = str_file + '_res.txt'
455
+ pred_filename = str_file + '_pred.txt'
456
+ gold_filename = str_file + '_gold.txt'
457
+
458
+ # save results dictionary into a file
459
+ with open(res_filename, 'w') as f:
460
+ json.dump(eval_dict, f)
461
+
462
+ # save predictions and gold labels into files
463
+ pred_writer = Writer(args.alphabets)
464
+ gold_writer = Writer(args.alphabets)
465
+ pred_writer.start(pred_filename)
466
+ gold_writer.start(gold_filename)
467
+ for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
468
+ word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
469
+ output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
470
+ auto_label_preds = model.decode(output, mask=masks, length=lengths,
471
+ leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
472
+ lengths = lengths.cpu().numpy()
473
+ word = word.data.cpu().numpy()
474
+ pos = pos.data.cpu().numpy()
475
+ ner = ner.data.cpu().numpy()
476
+ heads = heads.data.cpu().numpy()
477
+ arc_tags = arc_tags.data.cpu().numpy()
478
+ auto_label_preds = auto_label_preds.data.cpu().numpy()
479
+ # writing predictions
480
+ pred_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label_preds, symbolic_root=True)
481
+ # writing gold labels
482
+ gold_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label, symbolic_root=True)
483
+
484
+ pred_writer.close()
485
+ gold_writer.close()
486
+
487
+ def main():
488
+ logger.info("Reading and creating arguments")
489
+ args = read_arguments()
490
+ logger.info("Reading Data")
491
+ datasets = {}
492
+ for split in args.splits:
493
+ dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
494
+ symbolic_root=True)
495
+ datasets[split] = dataset
496
+
497
+ logger.info("Creating Networks")
498
+ num_data = sum(datasets['train'][1])
499
+ model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
500
+ best_model = deepcopy(model)
501
+ best_optimizer = deepcopy(optimizer)
502
+ logger.info('Training INFO of in domain %s' % args.domain)
503
+ logger.info('Training on Dependency Parsing')
504
+ print(model)
505
+ logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
506
+ logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
507
+ logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
508
+ logger.info("num_epochs: %d" % (args.num_epochs))
509
+ print('\n')
510
+
511
+ if not args.eval_mode:
512
+ logger.info("Training")
513
+ num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
514
+ lr = args.learning_rate
515
+ patient = 0
516
+ terminal_patient = 0
517
+ decay = 0
518
+ for epoch in range(start_epoch + 1, args.num_epochs + 1):
519
+ print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
520
+ epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
521
+ model.train()
522
+ total_loss = 0.0
523
+ total_train_inst = 0.0
524
+
525
+ batch_iter = prepare_data.iterate_batch_rand_bucket_choosing(
526
+ datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
527
+ start_time = time.time()
528
+ batch_num = 0
529
+ for batch_num, batch in enumerate(batch_iter):
530
+ batch_num = batch_num + 1  # 1-based batch count for progress logging
531
+ optimizer.zero_grad()
532
+ # compute loss of main task
533
+ word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
534
+ output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
535
+ loss = model.loss(output, auto_label, mask=masks, length=lengths)
536
+
537
+ # update losses
538
+ num_insts = masks.data.sum() - word.size(0)  # instance count: total masked tokens minus one position per sentence
539
+ total_loss += loss.item() * num_insts
540
+ total_train_inst += num_insts
541
+ # optimize parameters
542
+ loss.backward()
543
+ clip_grad_norm_(model.parameters(), args.clip)
544
+ optimizer.step()
545
+
546
+ time_ave = (time.time() - start_time) / batch_num
547
+ time_left = (num_batches - batch_num) * time_ave
548
+
549
+ # update log
550
+ if batch_num % 50 == 0:
551
+ log_info = 'train: %d/%d, domain: %s, total loss: %.2f, time left: %.2fs' % \
552
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, time_left)
553
+ sys.stdout.write(log_info)
554
+ sys.stdout.write('\n')
555
+ sys.stdout.flush()
556
+ print('\n')
557
+ print('train: %d/%d, domain: %s, total_loss: %.2f, time: %.2fs' %
558
+ (batch_num, num_batches, args.domain, total_loss / total_train_inst, time.time() - start_time))
559
+
560
+ dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
561
+ store = {'dev_eval_dict': curr_dev_eval_dict}
562
+ # append this epoch's dev results to a running log file
563
+ str_file = args.full_model_name + '_' + 'all_epochs'
564
+ with open(str_file, 'a') as f:
565
+ f.write(str(store) + '\n')
566
+ if patient == 0:
567
+ terminal_patient = 0
568
+ else:
569
+ terminal_patient += 1
570
+ if terminal_patient >= 4 * args.schedule:
571
+ # Save best model and terminate learning
572
+ cur_epoch = epoch
573
+ epoch = args.num_epochs
574
+ in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model,
575
+ best_optimizer, patient)
576
+ log_info = 'Terminating training in epoch %d' % (cur_epoch)
577
+ sys.stdout.write(log_info)
578
+ sys.stdout.write('\n')
579
+ sys.stdout.flush()
580
+ return
581
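+ # no dev improvement for args.schedule epochs: decay the learning rate and rebuild the optimizer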
+ if patient >= args.schedule:
582
+ lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
583
+ optimizer = generate_optimizer(args, lr, model.parameters())
584
+ print('updated learning rate to %.6f' % lr)
585
+ patient = 0
586
+ print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
587
+ print('\n')
588
+
589
+ else:
590
+ logger.info("Evaluating")
591
+ epoch = start_epoch
592
+ for split in ['train', 'dev', 'test']:
593
+ evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
594
+
595
+
596
+ if __name__ == '__main__':
597
+ main()
examples/eval/conll03eval.v2 ADDED
@@ -0,0 +1,336 @@
1
+ #!/usr/bin/perl -w
2
+ # conlleval: evaluate result of processing CoNLL-2000 shared task
3
+ # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4
+ # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5
+ # options: l: generate LaTeX output for tables like in
6
+ # http://cnts.uia.ac.be/conll2003/ner/example.tex
7
+ # r: accept raw result tags (without B- and I- prefix;
8
+ # assumes one word per chunk)
9
+ # d: alternative delimiter tag (default is single space)
10
+ # o: alternative outside tag (default is O)
11
+ # note: the file should contain lines with items separated
12
+ # by $delimiter characters (default space). The final
13
+ # two items should contain the correct tag and the
14
+ # guessed tag in that order. Sentences should be
15
+ # separated from each other by empty lines or lines
16
+ # with $boundary fields (default -X-).
17
+ # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18
+ # started: 1998-09-25
19
+ # version: 2004-01-26
20
+ # author: Erik Tjong Kim Sang <erikt@uia.ua.ac.be>
21
+
22
+ use strict;
23
+
24
+ my $false = 0;
25
+ my $true = 42;
26
+
27
+ my $boundary = "-X-"; # sentence boundary
28
+ my $correct; # current corpus chunk tag (I,O,B)
29
+ my $correctChunk = 0; # number of correctly identified chunks
30
+ my $correctTags = 0; # number of correct chunk tags
31
+ my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32
+ my $delimiter = " "; # field delimiter
33
+ my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34
+ my $firstItem; # first feature (for sentence boundary checks)
35
+ my $foundCorrect = 0; # number of chunks in corpus
36
+ my $foundGuessed = 0; # number of identified chunks
37
+ my $guessed; # current guessed chunk tag
38
+ my $guessedType; # type of current guessed chunk tag
39
+ my $i; # miscellaneous counter
40
+ my $inCorrect = $false; # currently processed chunk is correct until now
41
+ my $lastCorrect = "O"; # previous chunk tag in corpus
42
+ my $latex = 0; # generate LaTeX formatted output
43
+ my $lastCorrectType = ""; # type of previously identified chunk tag
44
+ my $lastGuessed = "O"; # previously identified chunk tag
45
+ my $lastGuessedType = ""; # type of previous chunk tag in corpus
46
+ my $lastType; # temporary storage for detecting duplicates
47
+ my $line; # line
48
+ my $nbrOfFeatures = -1; # number of features per line
49
+ my $precision = 0.0; # precision score
50
+ my $oTag = "O"; # outside tag, default O
51
+ my $raw = 0; # raw input: add B to every token
52
+ my $recall = 0.0; # recall score
53
+ my $tokenCounter = 0; # token counter (ignores sentence breaks)
54
+
55
+ my %correctChunk = (); # number of correctly identified chunks per type
56
+ my %foundCorrect = (); # number of chunks in corpus per type
57
+ my %foundGuessed = (); # number of identified chunks per type
58
+
59
+ my @features; # features on line
60
+ my @sortedTypes; # sorted list of chunk type names
61
+
62
+ # sanity check
63
+ while (@ARGV and $ARGV[0] =~ /^-/) {
64
+ if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65
+ elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66
+ elsif ($ARGV[0] eq "-d") {
67
+ shift(@ARGV);
68
+ if (not defined $ARGV[0]) {
69
+ die "conlleval: -d requires delimiter character";
70
+ }
71
+ $delimiter = shift(@ARGV);
72
+ } elsif ($ARGV[0] eq "-o") {
73
+ shift(@ARGV);
74
+ if (not defined $ARGV[0]) {
75
+ die "conlleval: -o requires delimiter character";
76
+ }
77
+ $oTag = shift(@ARGV);
78
+ } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79
+ }
80
+ if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81
+ # process input
82
+ while (<STDIN>) {
83
+ chomp($line = $_);
84
+ @features = split(/$delimiter/,$line);
85
+ if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
86
+ elsif ($nbrOfFeatures != $#features and @features != 0) {
87
+ printf STDERR "unexpected number of features: %d (%d)\n",
88
+ $#features+1,$nbrOfFeatures+1;
89
+ exit(1);
90
+ }
91
+ if (@features == 0 or
92
+ $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
93
+ if (@features < 2) {
94
+ die "conlleval: unexpected number of features in line $line\n";
95
+ }
96
+ if ($raw) {
97
+ if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
98
+ if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
99
+ if ($features[$#features] ne "O") {
100
+ $features[$#features] = "B-$features[$#features]";
101
+ }
102
+ if ($features[$#features-1] ne "O") {
103
+ $features[$#features-1] = "B-$features[$#features-1]";
104
+ }
105
+ }
106
+ # 20040126 ET code which allows hyphens in the types
107
+ if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
108
+ $guessed = $1;
109
+ $guessedType = $2;
110
+ } else {
111
+ $guessed = $features[$#features];
112
+ $guessedType = "";
113
+ }
114
+ pop(@features);
115
+ if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
116
+ $correct = $1;
117
+ $correctType = $2;
118
+ } else {
119
+ $correct = $features[$#features];
120
+ $correctType = "";
121
+ }
122
+ pop(@features);
123
+ # ($guessed,$guessedType) = split(/-/,pop(@features));
124
+ # ($correct,$correctType) = split(/-/,pop(@features));
125
+ $guessedType = $guessedType ? $guessedType : "";
126
+ $correctType = $correctType ? $correctType : "";
127
+ $firstItem = shift(@features);
128
+
129
+ # 1999-06-26 sentence breaks should always be counted as out of chunk
130
+ if ( $firstItem eq $boundary ) { $guessed = "O"; }
131
+
132
+ if ($inCorrect) {
133
+ if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
134
+ &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
135
+ $lastGuessedType eq $lastCorrectType) {
136
+ $inCorrect=$false;
137
+ $correctChunk++;
138
+ $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
139
+ $correctChunk{$lastCorrectType}+1 : 1;
140
+ } elsif (
141
+ &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
142
+ &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
143
+ $guessedType ne $correctType ) {
144
+ $inCorrect=$false;
145
+ }
146
+ }
147
+
148
+ if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
149
+ &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
150
+ $guessedType eq $correctType) { $inCorrect = $true; }
151
+
152
+ if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
153
+ $foundCorrect++;
154
+ $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
155
+ $foundCorrect{$correctType}+1 : 1;
156
+ }
157
+ if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
158
+ $foundGuessed++;
159
+ $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
160
+ $foundGuessed{$guessedType}+1 : 1;
161
+ }
162
+ if ( $firstItem ne $boundary ) {
163
+ if ( $correct eq $guessed and $guessedType eq $correctType ) {
164
+ $correctTags++;
165
+ }
166
+ $tokenCounter++;
167
+ }
168
+
169
+ $lastGuessed = $guessed;
170
+ $lastCorrect = $correct;
171
+ $lastGuessedType = $guessedType;
172
+ $lastCorrectType = $correctType;
173
+ }
174
+ if ($inCorrect) {
175
+ $correctChunk++;
176
+ $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
177
+ $correctChunk{$lastCorrectType}+1 : 1;
178
+ }
179
+
180
+ if (not $latex) {
181
+ # compute overall precision, recall and FB1 (default values are 0.0)
182
+ $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
183
+ $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
184
+ $FB1 = 2*$precision*$recall/($precision+$recall)
185
+ if ($precision+$recall > 0);
186
+
187
+ # print overall performance
188
+ printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
189
+ printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
190
+ if ($tokenCounter>0) {
191
+ printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
192
+ printf "precision: %6.2f%%; ",$precision;
193
+ printf "recall: %6.2f%%; ",$recall;
194
+ printf "FB1: %6.2f\n",$FB1;
195
+ }
196
+ }
197
+
198
+ # sort chunk type names
199
+ undef($lastType);
200
+ @sortedTypes = ();
201
+ foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
202
+ if (not($lastType) or $lastType ne $i) {
203
+ push(@sortedTypes,($i));
204
+ }
205
+ $lastType = $i;
206
+ }
207
+ # print performance per chunk type
208
+ if (not $latex) {
209
+ for $i (@sortedTypes) {
210
+ $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
211
+ if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
212
+ else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
213
+ if (not($foundCorrect{$i})) { $recall = 0.0; }
214
+ else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
215
+ if ($precision+$recall == 0.0) { $FB1 = 0.0; }
216
+ else { $FB1 = 2*$precision*$recall/($precision+$recall); }
217
+ printf "%17s: ",$i;
218
+ printf "precision: %6.2f%%; ",$precision;
219
+ printf "recall: %6.2f%%; ",$recall;
220
+ printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
221
+ }
222
+ } else {
223
+ print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
224
+ for $i (@sortedTypes) {
225
+ $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
226
+ if (not($foundGuessed{$i})) { $precision = 0.0; }
227
+ else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
228
+ if (not($foundCorrect{$i})) { $recall = 0.0; }
229
+ else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
230
+ if ($precision+$recall == 0.0) { $FB1 = 0.0; }
231
+ else { $FB1 = 2*$precision*$recall/($precision+$recall); }
232
+ printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
233
+ $i,$precision,$recall,$FB1;
234
+ }
235
+ print "\\hline\n";
236
+ $precision = 0.0;
237
+ $recall = 0;
238
+ $FB1 = 0.0;
239
+ $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
240
+ $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
241
+ $FB1 = 2*$precision*$recall/($precision+$recall)
242
+ if ($precision+$recall > 0);
243
+ printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
244
+ $precision,$recall,$FB1;
245
+ }
246
+
247
+ exit 0;
248
+
249
+ # endOfChunk: checks if a chunk ended between the previous and current word
250
+ # arguments: previous and current chunk tags, previous and current types
251
+ # note: this code is capable of handling other chunk representations
252
+ # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
253
+ # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
254
+
255
+ sub endOfChunk {
256
+ my $prevTag = shift(@_);
257
+ my $tag = shift(@_);
258
+ my $prevType = shift(@_);
259
+ my $type = shift(@_);
260
+ my $chunkEnd = $false;
261
+
262
+ if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
263
+ if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
264
+ if ( $prevTag eq "B" and $tag eq "S" ) { $chunkEnd = $true; }
265
+
266
+ if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
267
+ if ( $prevTag eq "I" and $tag eq "S" ) { $chunkEnd = $true; }
268
+ if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
269
+
270
+ if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
271
+ if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
272
+ if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
273
+ if ( $prevTag eq "E" and $tag eq "S" ) { $chunkEnd = $true; }
274
+ if ( $prevTag eq "E" and $tag eq "B" ) { $chunkEnd = $true; }
275
+
276
+ if ( $prevTag eq "S" and $tag eq "E" ) { $chunkEnd = $true; }
277
+ if ( $prevTag eq "S" and $tag eq "I" ) { $chunkEnd = $true; }
278
+ if ( $prevTag eq "S" and $tag eq "O" ) { $chunkEnd = $true; }
279
+ if ( $prevTag eq "S" and $tag eq "S" ) { $chunkEnd = $true; }
280
+ if ( $prevTag eq "S" and $tag eq "B" ) { $chunkEnd = $true; }
281
+
282
+
283
+ if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
284
+ $chunkEnd = $true;
285
+ }
286
+
287
+ # corrected 1998-12-22: these chunks are assumed to have length 1
288
+ if ( $prevTag eq "]" ) { $chunkEnd = $true; }
289
+ if ( $prevTag eq "[" ) { $chunkEnd = $true; }
290
+
291
+ return($chunkEnd);
292
+ }
293
+
294
+ # startOfChunk: checks if a chunk started between the previous and current word
295
+ # arguments: previous and current chunk tags, previous and current types
296
+ # note: this code is capable of handling other chunk representations
297
+ # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
298
+ # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
299
+
300
+ sub startOfChunk {
301
+ my $prevTag = shift(@_);
302
+ my $tag = shift(@_);
303
+ my $prevType = shift(@_);
304
+ my $type = shift(@_);
305
+ my $chunkStart = $false;
306
+
307
+ if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
308
+ if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
309
+ if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
310
+ if ( $prevTag eq "S" and $tag eq "B" ) { $chunkStart = $true; }
311
+ if ( $prevTag eq "E" and $tag eq "B" ) { $chunkStart = $true; }
312
+
313
+ if ( $prevTag eq "B" and $tag eq "S" ) { $chunkStart = $true; }
314
+ if ( $prevTag eq "I" and $tag eq "S" ) { $chunkStart = $true; }
315
+ if ( $prevTag eq "O" and $tag eq "S" ) { $chunkStart = $true; }
316
+ if ( $prevTag eq "S" and $tag eq "S" ) { $chunkStart = $true; }
317
+ if ( $prevTag eq "E" and $tag eq "S" ) { $chunkStart = $true; }
318
+
319
+ if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
320
+ if ( $prevTag eq "S" and $tag eq "I" ) { $chunkStart = $true; }
321
+ if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
322
+
323
+ if ( $prevTag eq "S" and $tag eq "E" ) { $chunkStart = $true; }
324
+ if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
325
+ if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
326
+
327
+ if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
328
+ $chunkStart = $true;
329
+ }
330
+
331
+ # corrected 1998-12-22: these chunks are assumed to have length 1
332
+ if ( $tag eq "[" ) { $chunkStart = $true; }
333
+ if ( $tag eq "]" ) { $chunkStart = $true; }
334
+
335
+ return($chunkStart);
336
+ }
examples/eval/conll06eval.pl ADDED
@@ -0,0 +1,1826 @@
1
+ #!/usr/bin/env perl
2
+
3
+ # Author: Yuval Krymolowski
4
+ # Addition of precision and recall
5
+ # and of frame confusion list: Sabine Buchholz
6
+ # Addition of DEPREL + ATTACHMENT:
7
+ # Prokopis Prokopidis (prokopis at ilsp dot gr)
8
+ # Acknowledgements:
9
+ # to Markus Kuhn for suggesting the use of
10
+ # the Unicode category property
11
+
12
+ if ($] < 5.008001)
13
+ {
14
+ printf STDERR <<EOM
15
+
16
+ This script requires PERL 5.8.1 for running.
17
+ The new version is needed for proper handling
18
+ of Unicode characters.
19
+
20
+ Please obtain a new version or contact the shared task team
21
+ if you are unable to upgrade PERL.
22
+
23
+ EOM
24
+ ;
25
+ exit(1) ;
26
+ }
27
+
28
+ require Encode;
29
+
30
+ use strict ;
31
+ use warnings;
32
+ use Getopt::Std ;
33
+
34
+ my ($usage) = <<EOT
35
+
36
+ CoNLL-X evaluation script:
37
+
38
+ [perl] eval.pl [OPTIONS] -g <gold standard> -s <system output>
39
+
40
+ This script evaluates a system output with respect to a gold standard.
41
+ Both files should be in UTF-8 encoded CoNLL-X tabular format.
42
+
43
+ Punctuation tokens (those where all characters have the Unicode
44
+ category property "Punctuation") are ignored for scoring (unless the
45
+ -p flag is used).
46
+
47
+ The output breaks down the errors according to their type and context.
48
+
49
+ Optional parameters:
50
+ -o FILE : output: print output to FILE (default is standard output)
51
+ -q : quiet: only print overall performance, without the details
52
+ -b : evalb: produce output in a format similar to evalb
53
+ (http://nlp.cs.nyu.edu/evalb/); use together with -q
54
+ -p : punctuation: also score on punctuation (default is not to score on it)
55
+ -v : version: show the version number
56
+ -h : help: print this help text and exit
57
+
58
+ EOT
59
+ ;
60
+
61
+ my ($line_num) ;
62
+ my ($sep) = '0x01' ;
63
+
64
+ my ($START) = '.S' ;
65
+ my ($END) = '.E' ;
66
+
67
+ my ($con_err_num) = 3 ;
68
+ my ($freq_err_num) = 10 ;
69
+ my ($spec_err_loc_con) = 8 ;
70
+
71
+ ################################################################################
72
+ ### subfunctions ###
73
+ ################################################################################
74
+
75
+ # Whether a string consists entirely of characters with the Unicode
76
+ # category property "Punctuation" (see "man perlunicode")
77
+ sub is_uni_punct
78
+ {
79
+ my ($word) = @_ ;
80
+
81
+ return scalar(Encode::decode_utf8($word)=~ /^\p{Punctuation}+$/) ;
82
+ }
83
+
84
+ # The length of a unicode string, excluding non-spacing marks
85
+ # (for example vowel marks in Arabic)
86
+
87
+ sub uni_len
88
+ {
89
+ my ($word) = @_ ;
90
+ my ($ch, $l) ;
91
+
92
+ $l = 0 ;
93
+ foreach $ch (split(//, Encode::decode_utf8($word)))
94
+ {
95
+ if ($ch !~ /^\p{NonspacingMark}/)
96
+ {
97
+ $l++ ;
98
+ }
99
+ }
100
+
101
+ return $l ;
102
+ }
103
+
104
+ sub filter_context_counts
105
+ { # filter_context_counts
106
+
107
+ my ($vec, $num, $max_len) = @_ ;
108
+ my ($con, $l, $thresh) ;
109
+
110
+ $thresh = (sort {$b <=> $a} values %{$vec})[$num-1] ;
111
+
112
+ foreach $con (keys %{$vec})
113
+ {
114
+ if (${$vec}{$con} < $thresh)
115
+ {
116
+ delete ${$vec}{$con} ;
117
+ next ;
118
+ }
119
+
120
+ $l = uni_len($con) ;
121
+
122
+ if ($l > ${$max_len})
123
+ {
124
+ ${$max_len} = $l ;
125
+ }
126
+ }
127
+
128
+ } # filter_context_counts
129
+
130
+ sub print_context
131
+ { # print_context
132
+
133
+ my ($counts, $counts_pos, $max_con_len, $max_con_pos_len) = @_ ;
134
+ my (@v_con, @v_con_pos, $con, $con_pos, $i, $n) ;
135
+
136
+ printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_pos_len, 'CPOS', 'any', 'head', 'dep', 'both' ;
137
+ printf OUT " ||" ;
138
+ printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_len, 'word', 'any', 'head', 'dep', 'both' ;
139
+ printf OUT "\n" ;
140
+ printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
141
+ printf OUT "--++" ;
142
+ printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
143
+ printf OUT "\n" ;
144
+
145
+ @v_con = sort {${$counts}{tot}{$b} <=> ${$counts}{tot}{$a}} keys %{${$counts}{tot}} ;
146
+ @v_con_pos = sort {${$counts_pos}{tot}{$b} <=> ${$counts_pos}{tot}{$a}} keys %{${$counts_pos}{tot}} ;
147
+
148
+ $n = scalar @v_con ;
149
+ if (scalar @v_con_pos > $n)
150
+ {
151
+ $n = scalar @v_con_pos ;
152
+ }
153
+
154
+ foreach $i (0 .. $n-1)
155
+ {
156
+ if (defined $v_con_pos[$i])
157
+ {
158
+ $con_pos = $v_con_pos[$i] ;
159
+ printf OUT " %-*s | %4d | %4d | %4d | %4d",
160
+ $max_con_pos_len, $con_pos, ${$counts_pos}{tot}{$con_pos},
161
+ ${$counts_pos}{err_head}{$con_pos}, ${$counts_pos}{err_dep}{$con_pos},
162
+ ${$counts_pos}{err_dep}{$con_pos}+${$counts_pos}{err_head}{$con_pos}-${$counts_pos}{tot}{$con_pos} ;
163
+ }
164
+ else
165
+ {
166
+ printf OUT " %-*s | %4s | %4s | %4s | %4s",
167
+ $max_con_pos_len, ' ', ' ', ' ', ' ', ' ' ;
168
+ }
169
+
170
+ printf OUT " ||" ;
171
+
172
+ if (defined $v_con[$i])
173
+ {
174
+ $con = $v_con[$i] ;
175
+ printf OUT " %-*s | %4d | %4d | %4d | %4d",
176
+ $max_con_len+length($con)-uni_len($con), $con, ${$counts}{tot}{$con},
177
+ ${$counts}{err_head}{$con}, ${$counts}{err_dep}{$con},
178
+ ${$counts}{err_dep}{$con}+${$counts}{err_head}{$con}-${$counts}{tot}{$con} ;
179
+ }
180
+ else
181
+ {
182
+ printf OUT " %-*s | %4s | %4s | %4s | %4s",
183
+ $max_con_len, ' ', ' ', ' ', ' ', ' ' ;
184
+ }
185
+
186
+ printf OUT "\n" ;
187
+ }
188
+
189
+ printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
190
+ printf OUT "--++" ;
191
+ printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
192
+ printf OUT "\n" ;
193
+
194
+ printf OUT "\n\n" ;
195
+
196
+ } # print_context
197
+
198
+ sub num_as_word
199
+ {
200
+ my ($num) = @_ ;
201
+
202
+ $num = abs($num) ;
203
+
204
+ if ($num == 1)
205
+ {
206
+ return ('one word') ;
207
+ }
208
+ elsif ($num == 2)
209
+ {
210
+ return ('two words') ;
211
+ }
212
+ elsif ($num == 3)
213
+ {
214
+ return ('three words') ;
215
+ }
216
+ elsif ($num == 4)
217
+ {
218
+ return ('four words') ;
219
+ }
220
+ else
221
+ {
222
+ return ($num.' words') ;
223
+ }
224
+ }
225
+
226
+ sub describe_err
227
+ { # describe_err
228
+
229
+ my ($head_err, $head_aft_bef, $dep_err) = @_ ;
230
+ my ($dep_g, $dep_s, $desc) ;
231
+ my ($head_aft_bef_g, $head_aft_bef_s) = split(//, $head_aft_bef) ;
232
+
233
+ if ($head_err eq '-')
234
+ {
235
+ $desc = 'correct head' ;
236
+
237
+ if ($head_aft_bef_s eq '0')
238
+ {
239
+ $desc .= ' (0)' ;
240
+ }
241
+ elsif ($head_aft_bef_s eq 'e')
242
+ {
243
+ $desc .= ' (the focus word)' ;
244
+ }
245
+ elsif ($head_aft_bef_s eq 'a')
246
+ {
247
+ $desc .= ' (after the focus word)' ;
248
+ }
249
+ elsif ($head_aft_bef_s eq 'b')
250
+ {
251
+ $desc .= ' (before the focus word)' ;
252
+ }
253
+ }
254
+ elsif ($head_aft_bef_s eq '0')
255
+ {
256
+ $desc = 'head = 0 instead of ' ;
257
+ if ($head_aft_bef_g eq 'a')
258
+ {
259
+ $desc.= 'after ' ;
260
+ }
261
+ if ($head_aft_bef_g eq 'b')
262
+ {
263
+ $desc.= 'before ' ;
264
+ }
265
+ $desc .= 'the focus word' ;
266
+ }
267
+ elsif ($head_aft_bef_g eq '0')
268
+ {
269
+ $desc = 'head is ' ;
270
+ if ($head_aft_bef_g eq 'a')
271
+ {
272
+ $desc.= 'after ' ;
273
+ }
274
+ if ($head_aft_bef_g eq 'b')
275
+ {
276
+ $desc.= 'before ' ;
277
+ }
278
+ $desc .= 'the focus word instead of 0' ;
279
+ }
280
+ else
281
+ {
282
+ $desc = num_as_word($head_err) ;
283
+ if ($head_err < 0)
284
+ {
285
+ $desc .= ' before' ;
286
+ }
287
+ else
288
+ {
289
+ $desc .= ' after' ;
290
+ }
291
+
292
+ $desc = 'head '.$desc.' the correct head ' ;
293
+
294
+ if ($head_aft_bef_s eq '0')
295
+ {
296
+ $desc .= '(0' ;
297
+ }
298
+ elsif ($head_aft_bef_s eq 'e')
299
+ {
300
+ $desc .= '(the focus word' ;
301
+ }
302
+ elsif ($head_aft_bef_s eq 'a')
303
+ {
304
+ $desc .= '(after the focus word' ;
305
+ }
306
+ elsif ($head_aft_bef_s eq 'b')
307
+ {
308
+ $desc .= '(before the focus word' ;
309
+ }
310
+
311
+ if ($head_aft_bef_g ne $head_aft_bef_s)
312
+ {
313
+ $desc .= ' instead of' ;
314
+ if ($head_aft_bef_s eq '0')
315
+ {
316
+ $desc .= '0' ;
317
+ }
318
+ elsif ($head_aft_bef_s eq 'e')
319
+ {
320
+ $desc .= 'the focus word' ;
321
+ }
322
+ elsif ($head_aft_bef_s eq 'a')
323
+ {
324
+ $desc .= 'after the focus word' ;
325
+ }
326
+ elsif ($head_aft_bef_s eq 'b')
327
+ {
328
+ $desc .= 'before the focus word' ;
329
+ }
330
+ }
331
+
332
+ $desc .= ')' ;
333
+ }
334
+
335
+ $desc .= ', ' ;
336
+
337
+ if ($dep_err eq '-')
338
+ {
339
+ $desc .= 'correct dependency' ;
340
+ }
341
+ else
342
+ {
343
+ ($dep_g, $dep_s) = ($dep_err =~ /^(.*)->(.*)$/) ;
344
+ $desc .= sprintf('dependency "%s" instead of "%s"', $dep_s, $dep_g) ;
345
+ }
346
+
347
+ return($desc) ;
348
+
349
+ } # describe_err
350
+
351
+ sub get_context
352
+ { # get_context
353
+
354
+ my ($sent, $i_w) = @_ ;
355
+ my ($w_2, $w_1, $w1, $w2) ;
356
+ my ($p_2, $p_1, $p1, $p2) ;
357
+
358
+ if ($i_w >= 2)
359
+ {
360
+ $w_2 = ${${$sent}[$i_w-2]}{word} ;
361
+ $p_2 = ${${$sent}[$i_w-2]}{pos} ;
362
+ }
363
+ else
364
+ {
365
+ $w_2 = $START ;
366
+ $p_2 = $START ;
367
+ }
368
+
369
+ if ($i_w >= 1)
370
+ {
371
+ $w_1 = ${${$sent}[$i_w-1]}{word} ;
372
+ $p_1 = ${${$sent}[$i_w-1]}{pos} ;
373
+ }
374
+ else
375
+ {
376
+ $w_1 = $START ;
377
+ $p_1 = $START ;
378
+ }
379
+
380
+ if ($i_w <= scalar @{$sent}-2)
381
+ {
382
+ $w1 = ${${$sent}[$i_w+1]}{word} ;
383
+ $p1 = ${${$sent}[$i_w+1]}{pos} ;
384
+ }
385
+ else
386
+ {
387
+ $w1 = $END ;
388
+ $p1 = $END ;
389
+ }
390
+
391
+ if ($i_w <= scalar @{$sent}-3)
392
+ {
393
+ $w2 = ${${$sent}[$i_w+2]}{word} ;
394
+ $p2 = ${${$sent}[$i_w+2]}{pos} ;
395
+ }
396
+ else
397
+ {
398
+ $w2 = $END ;
399
+ $p2 = $END ;
400
+ }
401
+
402
+ return ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) ;
403
+
404
+ } # get_context
405
+
406
+ sub read_sent
407
+ { # read_sent
408
+
409
+ my ($sent_gold, $sent_sys) = @_ ;
410
+ my ($line_g, $line_s, $new_sent) ;
411
+ my (%fields_g, %fields_s) ;
412
+
413
+ $new_sent = 1 ;
414
+
415
+ @{$sent_gold} = () ;
416
+ @{$sent_sys} = () ;
417
+
418
+ while (1)
419
+ { # main reading loop
420
+
421
+ $line_g = <GOLD> ;
422
+ $line_s = <SYS> ;
423
+
424
+ $line_num++ ;
425
+
426
+ # system output has fewer lines than gold standard
427
+ if ((defined $line_g) && (! defined $line_s))
428
+ {
429
+ printf STDERR "line mismatch, line %d:\n", $line_num ;
430
+ printf STDERR " gold: %s", $line_g ;
431
+ printf STDERR " sys : past end of file\n" ;
432
+ exit(1) ;
433
+ }
434
+
435
+ # system output has more lines than gold standard
436
+ if ((! defined $line_g) && (defined $line_s))
437
+ {
438
+ printf STDERR "line mismatch, line %d:\n", $line_num ;
439
+ printf STDERR " gold: past end of file\n" ;
440
+ printf STDERR " sys : %s", $line_s ;
441
+ exit(1) ;
442
+ }
443
+
444
+ # end of file reached for both
445
+ if ((! defined $line_g) && (! defined $line_s))
446
+ {
447
+ return (1) ;
448
+ }
449
+
450
+ # one contains end of sentence but other one does not
451
+ if (($line_g =~ /^\s+$/) != ($line_s =~ /^\s+$/))
452
+ {
453
+ printf STDERR "line mismatch, line %d:\n", $line_num ;
454
+ printf STDERR " gold: %s", $line_g ;
455
+ printf STDERR " sys : %s", $line_s ;
456
+ exit(1) ;
457
+ }
458
+
459
+ # end of sentence reached
460
+ if ($line_g =~ /^\s+$/)
461
+ {
462
+ return(0) ;
463
+ }
464
+
465
+ # now both lines contain information
466
+
467
+ if ($new_sent)
468
+ {
469
+ $new_sent = 0 ;
470
+ }
471
+
472
+ # 'official' column names
473
+ # options.output = ['id','form','lemma','cpostag','postag',
474
+ # 'feats','head','deprel','phead','pdeprel']
475
+
476
+ @fields_g{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_g))[1, 3, 6, 7] ;
477
+
478
+ push @{$sent_gold}, { %fields_g } ;
479
+
480
+ @fields_s{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_s))[1, 3, 6, 7] ;
481
+
482
+ if (($fields_g{word} ne $fields_s{word})
483
+ ||
484
+ ($fields_g{pos} ne $fields_s{pos}))
485
+ {
486
+ printf STDERR "Word/pos mismatch, line %d:\n", $line_num ;
487
+ printf STDERR " gold: %s", $line_g ;
488
+ printf STDERR " sys : %s", $line_s ;
489
+ exit(1) ;
490
+ }
491
+
492
+ push @{$sent_sys}, { %fields_s } ;
493
+
494
+ } # main reading loop
495
+
496
+ } # read_sent
497
+
498
+ ################################################################################
499
+ ### main ###
500
+ ################################################################################
501
+
502
+ our ($opt_g, $opt_s, $opt_o, $opt_h, $opt_v, $opt_q, $opt_p, $opt_b) ;
503
+
504
+ my ($sent_num, $eof, $word_num, @err_sent) ;
505
+ my (@sent_gold, @sent_sys, @starts) ;
506
+ my ($word, $pos, $wp, $head_g, $dep_g, $head_s, $dep_s) ;
507
+ my (%counts, $err_head, $err_dep, $con, $con1, $con_pos, $con_pos1, $thresh) ;
508
+ my ($head_err, $dep_err, @cur_err, %err_counts, $err_counter, $err_desc) ;
509
+ my ($loc_con, %loc_con_err_counts, %err_desc) ;
510
+ my ($head_aft_bef_g, $head_aft_bef_s, $head_aft_bef) ;
511
+ my ($con_bef, $con_aft, $con_bef_2, $con_aft_2, @bits, @e_bits, @v_con, @v_con_pos) ;
512
+ my ($con_pos_bef, $con_pos_aft, $con_pos_bef_2, $con_pos_aft_2) ;
513
+ my ($max_word_len, $max_pos_len, $max_con_len, $max_con_pos_len) ;
514
+ my ($max_word_spec_len, $max_con_bef_len, $max_con_aft_len) ;
515
+ my (%freq_err, $err) ;
516
+
517
+ my ($i, $j, $i_w, $l, $n_args) ;
518
+ my ($w_2, $w_1, $w1, $w2) ;
519
+ my ($wp_2, $wp_1, $wp1, $wp2) ;
520
+ my ($p_2, $p_1, $p1, $p2) ;
521
+
522
+ my ($short_output) ;
523
+ my ($score_on_punct) ;
524
+ $counts{punct} = 0; # initialize
525
+
526
+ getopts("g:o:s:qvhpb") ;
527
+
528
+ if (defined $opt_v)
529
+ {
530
+ my $id = '$Id: eval.pl,v 1.9 2006/05/09 20:30:01 yuval Exp $';
531
+ my @parts = split ' ',$id;
532
+ print "Version $parts[2]\n";
533
+ exit(0);
534
+ }
535
+
536
+ if ((defined $opt_h) || ((! defined $opt_g) && (! defined $opt_s)))
537
+ {
538
+ die $usage ;
539
+ }
540
+
541
+ if (! defined $opt_g)
542
+ {
543
+ die "Gold standard file (-g) missing\n" ;
544
+ }
545
+
546
+ if (! defined $opt_s)
547
+ {
548
+ die "System output file (-s) missing\n" ;
549
+ }
550
+
551
+ if (! defined $opt_o)
552
+ {
553
+ $opt_o = '-' ;
554
+ }
555
+
556
+ if (defined $opt_q)
557
+ {
558
+ $short_output = 1 ;
559
+ } else {
560
+ $short_output = 0 ;
561
+ }
562
+
563
+ if (defined $opt_p)
564
+ {
565
+ $score_on_punct = 1 ;
566
+ } else {
567
+ $score_on_punct = 0 ;
568
+ }
569
+
570
+ $line_num = 0 ;
571
+ $sent_num = 0 ;
572
+ $eof = 0 ;
573
+
574
+ @err_sent = () ;
575
+ @starts = () ;
576
+
577
+ %{$err_sent[0]} = () ;
578
+
579
+ $max_pos_len = length('CPOS') ;
580
+
581
+ ################################################################################
582
+ ### reading input ###
583
+ ################################################################################
584
+
585
+ open (GOLD, "<$opt_g") || die "Could not open gold standard file $opt_g\n" ;
586
+ open (SYS, "<$opt_s") || die "Could not open system output file $opt_s\n" ;
587
+ open (OUT, ">$opt_o") || die "Could not open output file $opt_o\n" ;
588
+
589
+
590
+ if (defined $opt_b) { # produce output similar to evalb
591
+ print OUT " Sent. Attachment Correct Scoring \n";
592
+ print OUT " ID Tokens - Unlab. Lab. HEAD HEAD+DEPREL tokens - - - -\n";
593
+ print OUT " ============================================================================\n";
594
+ }
595
+
596
+
597
+ while (! $eof)
598
+ { # main reading loop
599
+
600
+ $starts[$sent_num] = $line_num+1 ;
601
+ $eof = read_sent(\@sent_gold, \@sent_sys) ;
602
+
603
+ $sent_num++ ;
604
+
605
+ %{$err_sent[$sent_num]} = () ;
606
+ $word_num = scalar @sent_gold ;
607
+
608
+ # for accuracy per sentence
609
+ my %sent_counts = ( tot => 0,
610
+ err_any => 0,
611
+ err_head => 0
612
+ );
613
+
614
+ # printf "$sent_num $word_num\n" ;
615
+
616
+ my @frames_g = ('** '); # the initial frame for the virtual root
617
+ my @frames_s = ('** '); # the initial frame for the virtual root
618
+ foreach $i_w (0 .. $word_num-1)
619
+ { # loop on words
620
+ push @frames_g, ''; # initialize
621
+ push @frames_s, ''; # initialize
622
+ }
623
+
624
+ foreach $i_w (0 .. $word_num-1)
625
+ { # loop on words
626
+
627
+ ($word, $pos, $head_g, $dep_g)
628
+ = @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
629
+ $wp = $word.' / '.$pos ;
630
+
631
+ # printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
632
+
633
+ if ((! $score_on_punct) && is_uni_punct($word))
634
+ {
635
+ $counts{punct}++ ;
636
+ # ignore punctuations
637
+ next ;
638
+ }
639
+
640
+ if (length($pos) > $max_pos_len)
641
+ {
642
+ $max_pos_len = length($pos) ;
643
+ }
644
+
645
+ ($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
646
+
647
+ $counts{tot}++ ;
648
+ $counts{word}{$wp}{tot}++ ;
649
+ $counts{pos}{$pos}{tot}++ ;
650
+ $counts{head}{$head_g-$i_w-1}{tot}++ ;
651
+
652
+ # for frame confusions
653
+ # add child to frame of parent
654
+ $frames_g[$head_g] .= "$dep_g ";
655
+ $frames_s[$head_s] .= "$dep_s ";
656
+ # add to frame of token itself
657
+ $frames_g[$i_w+1] .= "*$dep_g* "; # $i_w+1 because $i_w starts counting at zero
658
+ $frames_s[$i_w+1] .= "*$dep_g* ";
659
+
660
+ # for precision and recall of DEPREL
661
+ $counts{dep}{$dep_g}{tot}++ ; # counts for gold standard deprels
662
+ $counts{dep2}{$dep_g}{$dep_s}++ ; # counts for confusions
663
+ $counts{dep_s}{$dep_s}{tot}++ ; # counts for system deprels
664
+ $counts{all_dep}{$dep_g} = 1 ; # list of all deprels that occur ...
665
+ $counts{all_dep}{$dep_s} = 1 ; # ... in either gold or system output
666
+
667
+ # for precision and recall of HEAD direction
668
+ my $dir_g;
669
+ if ($head_g == 0) {
670
+ $dir_g = 'to_root';
671
+ } elsif ($head_g < $i_w+1) { # $i_w+1 because $i_w starts counting at zero
672
+ # also below
673
+ $dir_g = 'left';
674
+ } elsif ($head_g > $i_w+1) {
675
+ $dir_g = 'right';
676
+ } else {
677
+ # token links to itself; should never happen in correct gold standard
678
+ $dir_g = 'self';
679
+ }
680
+ my $dir_s;
681
+ if ($head_s == 0) {
682
+ $dir_s = 'to_root';
683
+ } elsif ($head_s < $i_w+1) {
684
+ $dir_s = 'left';
685
+ } elsif ($head_s > $i_w+1) {
686
+ $dir_s = 'right';
687
+ } else {
688
+ # token links to itself; should not happen in good system
689
+ # (but not forbidden in shared task)
690
+ $dir_s = 'self';
691
+ }
692
+ $counts{dir_g}{$dir_g}{tot}++ ; # counts for gold standard head direction
693
+ $counts{dir2}{$dir_g}{$dir_s}++ ; # counts for confusions
694
+ $counts{dir_s}{$dir_s}{tot}++ ; # counts for system head direction
695
+
696
+ # for precision and recall of HEAD distance
697
+ my $dist_g;
698
+ if ($head_g == 0) {
699
+ $dist_g = 'to_root';
700
+ } elsif ( abs($head_g - ($i_w+1)) <= 1 ) {
701
+ $dist_g = '1'; # includes the 'self' cases
702
+ } elsif ( abs($head_g - ($i_w+1)) <= 2 ) {
703
+ $dist_g = '2';
704
+ } elsif ( abs($head_g - ($i_w+1)) <= 6 ) {
705
+ $dist_g = '3-6';
706
+ } else {
707
+ $dist_g = '7-...';
708
+ }
709
+ my $dist_s;
710
+ if ($head_s == 0) {
711
+ $dist_s = 'to_root';
712
+ } elsif ( abs($head_s - ($i_w+1)) <= 1 ) {
713
+ $dist_s = '1'; # includes the 'self' cases
714
+ } elsif ( abs($head_s - ($i_w+1)) <= 2 ) {
715
+ $dist_s = '2';
716
+ } elsif ( abs($head_s - ($i_w+1)) <= 6 ) {
717
+ $dist_s = '3-6';
718
+ } else {
719
+ $dist_s = '7-...';
720
+ }
721
+ $counts{dist_g}{$dist_g}{tot}++ ; # counts for gold standard head distance
722
+ $counts{dist2}{$dist_g}{$dist_s}++ ; # counts for confusions
723
+ $counts{dist_s}{$dist_s}{tot}++ ; # counts for system head distance
724
+
725
+
726
+ $err_head = ($head_g ne $head_s) ; # error in head
727
+ $err_dep = ($dep_g ne $dep_s) ; # error in deprel
728
+
729
+ $head_err = '-' ;
730
+ $dep_err = '-' ;
731
+
732
+ # for accuracy per sentence
733
+ $sent_counts{tot}++ ;
734
+ if ($err_dep || $err_head) {
735
+ $sent_counts{err_any}++ ;
736
+ }
737
+ if ($err_head) {
738
+ $sent_counts{err_head}++ ;
739
+ }
740
+
741
+ # total counts and counts for CPOS involved in errors
742
+
743
+ if ($head_g eq '0')
744
+ {
745
+ $head_aft_bef_g = '0' ;
746
+ }
747
+ elsif ($head_g eq $i_w+1)
748
+ {
749
+ $head_aft_bef_g = 'e' ;
750
+ }
751
+ else
752
+ {
753
+ $head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
754
+ }
755
+
756
+ if ($head_s eq '0')
757
+ {
758
+ $head_aft_bef_s = '0' ;
759
+ }
760
+ elsif ($head_s eq $i_w+1)
761
+ {
762
+ $head_aft_bef_s = 'e' ;
763
+ }
764
+ else
765
+ {
766
+ $head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
767
+ }
768
+
769
+ $head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
770
+
771
+ if ($err_head)
772
+ {
773
+ if ($head_aft_bef_s eq '0')
774
+ {
775
+ $head_err = 0 ;
776
+ }
777
+ else
778
+ {
779
+ $head_err = $head_s-$head_g ;
780
+ }
781
+
782
+ $err_sent[$sent_num]{head}++ ;
783
+ $counts{err_head}{tot}++ ;
784
+ $counts{err_head}{$head_err}++ ;
785
+
786
+ $counts{word}{err_head}{$wp}++ ;
787
+ $counts{pos}{$pos}{err_head}{tot}++ ;
788
+ $counts{pos}{$pos}{err_head}{$head_err}++ ;
789
+ }
790
+
791
+ if ($err_dep)
792
+ {
793
+ $dep_err = $dep_g.'->'.$dep_s ;
794
+ $err_sent[$sent_num]{dep}++ ;
795
+ $counts{err_dep}{tot}++ ;
796
+ $counts{err_dep}{$dep_err}++ ;
797
+
798
+ $counts{word}{err_dep}{$wp}++ ;
799
+ $counts{pos}{$pos}{err_dep}{tot}++ ;
800
+ $counts{pos}{$pos}{err_dep}{$dep_err}++ ;
801
+
802
+ if ($err_head)
803
+ {
804
+ $counts{err_both}++ ;
805
+ $counts{pos}{$pos}{err_both}++ ;
806
+ }
807
+ }
808
+
809
+ ### DEPREL + ATTACHMENT
810
+ if ((!$err_dep) && ($err_head)) {
811
+ $counts{err_head_corr_dep}{tot}++ ;
812
+ $counts{err_head_corr_dep}{$dep_s}++ ;
813
+ }
814
+ ### DEPREL + ATTACHMENT
815
+
816
+ # counts for words involved in errors
817
+
818
+ if (! ($err_head || $err_dep))
819
+ {
820
+ next ;
821
+ }
822
+
823
+ $err_sent[$sent_num]{word}++ ;
824
+ $counts{err_any}++ ;
825
+ $counts{word}{err_any}{$wp}++ ;
826
+ $counts{pos}{$pos}{err_any}++ ;
827
+
828
+ ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
829
+
830
+ if ($w_2 ne $START)
831
+ {
832
+ $wp_2 = $w_2.' / '.$p_2 ;
833
+ }
834
+ else
835
+ {
836
+ $wp_2 = $w_2 ;
837
+ }
838
+
839
+ if ($w_1 ne $START)
840
+ {
841
+ $wp_1 = $w_1.' / '.$p_1 ;
842
+ }
843
+ else
844
+ {
845
+ $wp_1 = $w_1 ;
846
+ }
847
+
848
+ if ($w1 ne $END)
849
+ {
850
+ $wp1 = $w1.' / '.$p1 ;
851
+ }
852
+ else
853
+ {
854
+ $wp1 = $w1 ;
855
+ }
856
+
857
+ if ($w2 ne $END)
858
+ {
859
+ $wp2 = $w2.' / '.$p2 ;
860
+ }
861
+ else
862
+ {
863
+ $wp2 = $w2 ;
864
+ }
865
+
866
+ $con_bef = $wp_1 ;
867
+ $con_bef_2 = $wp_2.' + '.$wp_1 ;
868
+ $con_aft = $wp1 ;
869
+ $con_aft_2 = $wp1.' + '.$wp2 ;
870
+
871
+ $con_pos_bef = $p_1 ;
872
+ $con_pos_bef_2 = $p_2.'+'.$p_1 ;
873
+ $con_pos_aft = $p1 ;
874
+ $con_pos_aft_2 = $p1.'+'.$p2 ;
875
+
876
+ if ($w_1 ne $START)
877
+ {
878
+ # do not count '.S' as a word context
879
+ $counts{con_bef_2}{tot}{$con_bef_2}++ ;
880
+ $counts{con_bef_2}{err_head}{$con_bef_2} += $err_head ;
881
+ $counts{con_bef_2}{err_dep}{$con_bef_2} += $err_dep ;
882
+ $counts{con_bef}{tot}{$con_bef}++ ;
883
+ $counts{con_bef}{err_head}{$con_bef} += $err_head ;
884
+ $counts{con_bef}{err_dep}{$con_bef} += $err_dep ;
885
+ }
886
+
887
+ if ($w1 ne $END)
888
+ {
889
+ # do not count '.E' as a word context
890
+ $counts{con_aft_2}{tot}{$con_aft_2}++ ;
891
+ $counts{con_aft_2}{err_head}{$con_aft_2} += $err_head ;
892
+ $counts{con_aft_2}{err_dep}{$con_aft_2} += $err_dep ;
893
+ $counts{con_aft}{tot}{$con_aft}++ ;
894
+ $counts{con_aft}{err_head}{$con_aft} += $err_head ;
895
+ $counts{con_aft}{err_dep}{$con_aft} += $err_dep ;
896
+ }
897
+
898
+ $counts{con_pos_bef_2}{tot}{$con_pos_bef_2}++ ;
899
+ $counts{con_pos_bef_2}{err_head}{$con_pos_bef_2} += $err_head ;
900
+ $counts{con_pos_bef_2}{err_dep}{$con_pos_bef_2} += $err_dep ;
901
+ $counts{con_pos_bef}{tot}{$con_pos_bef}++ ;
902
+ $counts{con_pos_bef}{err_head}{$con_pos_bef} += $err_head ;
903
+ $counts{con_pos_bef}{err_dep}{$con_pos_bef} += $err_dep ;
904
+
905
+ $counts{con_pos_aft_2}{tot}{$con_pos_aft_2}++ ;
906
+ $counts{con_pos_aft_2}{err_head}{$con_pos_aft_2} += $err_head ;
907
+ $counts{con_pos_aft_2}{err_dep}{$con_pos_aft_2} += $err_dep ;
908
+ $counts{con_pos_aft}{tot}{$con_pos_aft}++ ;
909
+ $counts{con_pos_aft}{err_head}{$con_pos_aft} += $err_head ;
910
+ $counts{con_pos_aft}{err_dep}{$con_pos_aft} += $err_dep ;
911
+
912
+ $err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
913
+ $freq_err{$err}++ ;
914
+
915
+ } # loop on words
916
+
917
+ foreach $i_w (0 .. $word_num) # including one for the virtual root
918
+ { # loop on words
919
+ if ($frames_g[$i_w] ne $frames_s[$i_w]) {
920
+ $counts{frame2}{"$frames_g[$i_w]/ $frames_s[$i_w]"}++ ;
921
+ }
922
+ }
923
+
924
+ if (defined $opt_b) { # produce output similar to evalb
925
+ if ($word_num > 0) {
926
+ my ($unlabeled,$labeled) = ('NaN', 'NaN');
927
+ if ($sent_counts{tot} > 0) { # there are scoring tokens
928
+ $unlabeled = 100-$sent_counts{err_head}*100.0/$sent_counts{tot};
929
+ $labeled = 100-$sent_counts{err_any} *100.0/$sent_counts{tot};
930
+ }
931
+ printf OUT " %4d %4d 0 %6.2f %6.2f %4d %4d %4d 0 0 0 0\n",
932
+ $sent_num, $word_num,
933
+ $unlabeled, $labeled,
934
+ $sent_counts{tot}-$sent_counts{err_head},
935
+ $sent_counts{tot}-$sent_counts{err_any},
936
+ $sent_counts{tot},;
937
+ }
938
+ }
939
+
940
+ } # main reading loop
941
+
942
+ ################################################################################
943
+ ### printing output ###
944
+ ################################################################################
945
+
946
+ if (defined $opt_b) { # produce output similar to evalb
947
+ print OUT "\n\n";
948
+ }
949
+ printf OUT " Labeled attachment score: %d / %d * 100 = %.2f %%\n",
950
+ $counts{tot}-$counts{err_any}, $counts{tot}, 100-$counts{err_any}*100.0/$counts{tot} ;
951
+ printf OUT " Unlabeled attachment score: %d / %d * 100 = %.2f %%\n",
952
+ $counts{tot}-$counts{err_head}{tot}, $counts{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot} ;
953
+ printf OUT " Label accuracy score: %d / %d * 100 = %.2f %%\n",
954
+ $counts{tot}-$counts{err_dep}{tot}, $counts{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot} ;
955
+
956
+ if ($short_output)
957
+ {
958
+ exit(0) ;
959
+ }
960
+ printf OUT "\n %s\n\n", '=' x 80 ;
961
+ printf OUT " Evaluation of the results in %s\n vs. gold standard %s:\n\n", $opt_s, $opt_g ;
962
+
963
+ printf OUT " Legend: '%s' - the beginning of a sentence, '%s' - the end of a sentence\n\n", $START, $END ;
964
+
965
+ printf OUT " Number of non-scoring tokens: $counts{punct}\n\n";
966
+
967
+ printf OUT " The overall accuracy and its distribution over CPOSTAGs\n\n" ;
968
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
969
+
970
+ printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
971
+ 'Accuracy', 'words', 'right', 'right', 'both' ;
972
+ printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
973
+ ' ', ' ', 'head', ' dep', 'right' ;
974
+
975
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
976
+
977
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
978
+ 'total', $counts{tot},
979
+ $counts{tot}-$counts{err_head}{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot},
980
+ $counts{tot}-$counts{err_dep}{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot},
981
+ $counts{tot}-$counts{err_any}, 100-$counts{err_any}*100.0/$counts{tot} ;
982
+
983
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
984
+
985
+ foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
986
+ {
987
+ if (! defined($counts{pos}{$pos}{err_head}{tot}))
988
+ {
989
+ $counts{pos}{$pos}{err_head}{tot} = 0 ;
990
+ }
991
+ if (! defined($counts{pos}{$pos}{err_dep}{tot}))
992
+ {
993
+ $counts{pos}{$pos}{err_dep}{tot} = 0 ;
994
+ }
995
+ if (! defined($counts{pos}{$pos}{err_any}))
996
+ {
997
+ $counts{pos}{$pos}{err_any} = 0 ;
998
+ }
999
+
1000
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
1001
+ $pos, $counts{pos}{$pos}{tot},
1002
+ $counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_head}{tot}, 100-$counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
1003
+ $counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_dep}{tot}, 100-$counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
1004
+ $counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_any}, 100-$counts{pos}{$pos}{err_any}*100.0/$counts{pos}{$pos}{tot} ;
1005
+ }
1006
+
1007
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1008
+
1009
+ printf OUT "\n\n" ;
1010
+
1011
+ printf OUT " The overall error rate and its distribution over CPOSTAGs\n\n" ;
1012
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1013
+
1014
+ printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
1015
+ 'Error', 'words', 'head', ' dep', 'both' ;
1016
+ printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
1017
+
1018
+ 'Rate', ' ', 'err', ' err', 'wrong' ;
1019
+
1020
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1021
+
1022
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
1023
+ 'total', $counts{tot},
1024
+ $counts{err_head}{tot}, $counts{err_head}{tot}*100.0/$counts{tot},
1025
+ $counts{err_dep}{tot}, $counts{err_dep}{tot}*100.0/$counts{tot},
1026
+ $counts{err_both}, $counts{err_both}*100.0/$counts{tot} ;
1027
+
1028
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1029
+
1030
+ foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
1031
+ {
1032
+ if (! defined($counts{pos}{$pos}{err_both}))
1033
+ {
1034
+ $counts{pos}{$pos}{err_both} = 0 ;
1035
+ }
1036
+
1037
+ printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
1038
+ $pos, $counts{pos}{$pos}{tot},
1039
+ $counts{pos}{$pos}{err_head}{tot}, $counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
1040
+ $counts{pos}{$pos}{err_dep}{tot}, $counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
1041
+ $counts{pos}{$pos}{err_both}, $counts{pos}{$pos}{err_both}*100.0/$counts{pos}{$pos}{tot} ;
1042
+
1043
+ }
1044
+
1045
+ printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
1046
+
1047
+ ### added by Sabine Buchholz
1048
+ printf OUT "\n\n";
1049
+ printf OUT " Precision and recall of DEPREL\n\n";
1050
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1051
+ printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
1052
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1053
+ foreach my $dep (sort keys %{$counts{all_dep}}) {
1054
+ # initialize
1055
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1056
+
1057
+ if (defined($counts{dep2}{$dep}{$dep})) {
1058
+ $tot_corr = $counts{dep2}{$dep}{$dep};
1059
+ }
1060
+ if (defined($counts{dep}{$dep}{tot})) {
1061
+ $tot_g = $counts{dep}{$dep}{tot};
1062
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1063
+ }
1064
+ if (defined($counts{dep_s}{$dep}{tot})) {
1065
+ $tot_s = $counts{dep_s}{$dep}{tot};
1066
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1067
+ }
1068
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1069
+ $dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1070
+ }
1071
+
1072
+ ### DEPREL + ATTACHMENT:
1073
+ ### Same as Sabine's DEPREL apart from $tot_corr calculation
1074
+ printf OUT "\n\n";
1075
+ printf OUT " Precision and recall of DEPREL + ATTACHMENT\n\n";
1076
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1077
+ printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
1078
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1079
+ foreach my $dep (sort keys %{$counts{all_dep}}) {
1080
+ # initialize
1081
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1082
+
1083
+ if (defined($counts{dep2}{$dep}{$dep})) {
1084
+ if (defined($counts{err_head_corr_dep}{$dep})) {
1085
+ $tot_corr = $counts{dep2}{$dep}{$dep} - $counts{err_head_corr_dep}{$dep};
1086
+ } else {
1087
+ $tot_corr = $counts{dep2}{$dep}{$dep};
1088
+ }
1089
+ }
1090
+ if (defined($counts{dep}{$dep}{tot})) {
1091
+ $tot_g = $counts{dep}{$dep}{tot};
1092
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1093
+ }
1094
+ if (defined($counts{dep_s}{$dep}{tot})) {
1095
+ $tot_s = $counts{dep_s}{$dep}{tot};
1096
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1097
+ }
1098
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1099
+ $dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1100
+ }
1101
+ ### DEPREL + ATTACHMENT
1102
+
1103
+ printf OUT "\n\n";
1104
+ printf OUT " Precision and recall of binned HEAD direction\n\n";
1105
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1106
+ printf OUT " direction | gold | correct | system | recall (%%) | precision (%%) \n";
1107
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1108
+ foreach my $dir ('to_root', 'left', 'right', 'self') {
1109
+ # initialize
1110
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1111
+
1112
+ if (defined($counts{dir2}{$dir}{$dir})) {
1113
+ $tot_corr = $counts{dir2}{$dir}{$dir};
1114
+ }
1115
+ if (defined($counts{dir_g}{$dir}{tot})) {
1116
+ $tot_g = $counts{dir_g}{$dir}{tot};
1117
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1118
+ }
1119
+ if (defined($counts{dir_s}{$dir}{tot})) {
1120
+ $tot_s = $counts{dir_s}{$dir}{tot};
1121
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1122
+ }
1123
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1124
+ $dir, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1125
+ }
1126
+
1127
+ printf OUT "\n\n";
1128
+ printf OUT " Precision and recall of binned HEAD distance\n\n";
1129
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1130
+ printf OUT " distance | gold | correct | system | recall (%%) | precision (%%) \n";
1131
+ printf OUT " ----------------+------+---------+--------+------------+---------------\n";
1132
+ foreach my $dist ('to_root', '1', '2', '3-6', '7-...') {
1133
+ # initialize
1134
+ my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
1135
+
1136
+ if (defined($counts{dist2}{$dist}{$dist})) {
1137
+ $tot_corr = $counts{dist2}{$dist}{$dist};
1138
+ }
1139
+ if (defined($counts{dist_g}{$dist}{tot})) {
1140
+ $tot_g = $counts{dist_g}{$dist}{tot};
1141
+ $rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
1142
+ }
1143
+ if (defined($counts{dist_s}{$dist}{tot})) {
1144
+ $tot_s = $counts{dist_s}{$dist}{tot};
1145
+ $prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
1146
+ }
1147
+ printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
1148
+ $dist, $tot_g, $tot_corr, $tot_s, $rec, $prec;
1149
+ }
1150
+
1151
+ printf OUT "\n\n";
1152
+ printf OUT " Frame confusions (gold versus system; *...* marks the head token)\n\n";
1153
+ foreach my $frame (sort {$counts{frame2}{$b} <=> $counts{frame2}{$a}} keys %{$counts{frame2}})
1154
+ {
1155
+ if ($counts{frame2}{$frame} >= 5) # (make 5 a changeable threshold later)
1156
+ {
1157
+ printf OUT " %3d %s\n", $counts{frame2}{$frame}, $frame;
1158
+ }
1159
+ }
1160
+ ### end of: added by Sabine Buchholz
1161
+
1162
+
1163
+ #
1164
+ # Leave only the 5 words mostly involved in errors
1165
+ #
1166
+
1167
+
1168
+ $thresh = (sort {$b <=> $a} values %{$counts{word}{err_any}})[4] ;
1169
+
1170
+ # ensure enough space for title
1171
+ $max_word_len = length('word') ;
1172
+
1173
+ foreach $word (keys %{$counts{word}{err_any}})
1174
+ {
1175
+ if ($counts{word}{err_any}{$word} < $thresh)
1176
+ {
1177
+ delete $counts{word}{err_any}{$word} ;
1178
+ next ;
1179
+ }
1180
+
1181
+ $l = uni_len($word) ;
1182
+ if ($l > $max_word_len)
1183
+ {
1184
+ $max_word_len = $l ;
1185
+ }
1186
+ }
1187
+
1188
+ # filter a case when the difference between the error counts
1189
+ # for 2-word and 1-word contexts is small
1190
+ # (leave the 2-word context)
1191
+
1192
+ foreach $con (keys %{$counts{con_aft_2}{tot}})
1193
+ {
1194
+ ($w1) = split(/\+/, $con) ;
1195
+
1196
+ if (defined $counts{con_aft}{tot}{$w1} &&
1197
+ $counts{con_aft}{tot}{$w1}-$counts{con_aft_2}{tot}{$con} <= 1)
1198
+ {
1199
+ delete $counts{con_aft}{tot}{$w1} ;
1200
+ }
1201
+ }
1202
+
1203
+ foreach $con (keys %{$counts{con_bef_2}{tot}})
1204
+ {
1205
+ ($w_2, $w_1) = split(/\+/, $con) ;
1206
+
1207
+ if (defined $counts{con_bef}{tot}{$w_1} &&
1208
+ $counts{con_bef}{tot}{$w_1}-$counts{con_bef_2}{tot}{$con} <= 1)
1209
+ {
1210
+ delete $counts{con_bef}{tot}{$w_1} ;
1211
+ }
1212
+ }
1213
+
1214
+ foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
1215
+ {
1216
+ ($p1) = split(/\+/, $con_pos) ;
1217
+
1218
+ if (defined($counts{con_pos_aft}{tot}{$p1}) &&
1219
+ $counts{con_pos_aft}{tot}{$p1}-$counts{con_pos_aft_2}{tot}{$con_pos} <= 1)
1220
+ {
1221
+ delete $counts{con_pos_aft}{tot}{$p1} ;
1222
+ }
1223
+ }
1224
+
1225
+ foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
1226
+ {
1227
+ ($p_2, $p_1) = split(/\+/, $con_pos) ;
1228
+
1229
+ if (defined($counts{con_pos_bef}{tot}{$p_1}) &&
1230
+ $counts{con_pos_bef}{tot}{$p_1}-$counts{con_pos_bef_2}{tot}{$con_pos} <= 1)
1231
+ {
1232
+ delete $counts{con_pos_bef}{tot}{$p_1} ;
1233
+ }
1234
+ }
1235
+
1236
+ # for each context type, take the three contexts most involved in errors
1237
+
1238
+ $max_con_len = 0 ;
1239
+
1240
+ filter_context_counts($counts{con_bef_2}{tot}, $con_err_num, \$max_con_len) ;
1241
+
1242
+ filter_context_counts($counts{con_bef}{tot}, $con_err_num, \$max_con_len) ;
1243
+
1244
+ filter_context_counts($counts{con_aft}{tot}, $con_err_num, \$max_con_len) ;
1245
+
1246
+ filter_context_counts($counts{con_aft_2}{tot}, $con_err_num, \$max_con_len) ;
1247
+
1248
+ # for each CPOS context type, take the three CPOS contexts most involved in errors
1249
+
1250
+ $max_con_pos_len = 0 ;
1251
+
1252
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef_2}{tot}})[$con_err_num-1] ;
1253
+
1254
+ foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
1255
+ {
1256
+ if ($counts{con_pos_bef_2}{tot}{$con_pos} < $thresh)
1257
+ {
1258
+ delete $counts{con_pos_bef_2}{tot}{$con_pos} ;
1259
+ next ;
1260
+ }
1261
+ if (length($con_pos) > $max_con_pos_len)
1262
+ {
1263
+ $max_con_pos_len = length($con_pos) ;
1264
+ }
1265
+ }
1266
+
1267
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef}{tot}})[$con_err_num-1] ;
1268
+
1269
+ foreach $con_pos (keys %{$counts{con_pos_bef}{tot}})
1270
+ {
1271
+ if ($counts{con_pos_bef}{tot}{$con_pos} < $thresh)
1272
+ {
1273
+ delete $counts{con_pos_bef}{tot}{$con_pos} ;
1274
+ next ;
1275
+ }
1276
+ if (length($con_pos) > $max_con_pos_len)
1277
+ {
1278
+ $max_con_pos_len = length($con_pos) ;
1279
+ }
1280
+ }
1281
+
1282
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft}{tot}})[$con_err_num-1] ;
1283
+
1284
+ foreach $con_pos (keys %{$counts{con_pos_aft}{tot}})
1285
+ {
1286
+ if ($counts{con_pos_aft}{tot}{$con_pos} < $thresh)
1287
+ {
1288
+ delete $counts{con_pos_aft}{tot}{$con_pos} ;
1289
+ next ;
1290
+ }
1291
+ if (length($con_pos) > $max_con_pos_len)
1292
+ {
1293
+ $max_con_pos_len = length($con_pos) ;
1294
+ }
1295
+ }
1296
+
1297
+ $thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft_2}{tot}})[$con_err_num-1] ;
1298
+
1299
+ foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
1300
+ {
1301
+ if ($counts{con_pos_aft_2}{tot}{$con_pos} < $thresh)
1302
+ {
1303
+ delete $counts{con_pos_aft_2}{tot}{$con_pos} ;
1304
+ next ;
1305
+ }
1306
+ if (length($con_pos) > $max_con_pos_len)
1307
+ {
1308
+ $max_con_pos_len = length($con_pos) ;
1309
+ }
1310
+ }
1311
+
1312
+ # printing
1313
+
1314
+ # ------------- focus words
1315
+
1316
+ printf OUT "\n\n" ;
1317
+ printf OUT " %d focus words where most of the errors occur:\n\n", scalar keys %{$counts{word}{err_any}} ;
1318
+
1319
+ printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s\n", $max_word_len, ' ', 'any', 'head', 'dep', 'both' ;
1320
+ printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
1321
+
1322
+ foreach $word (sort {$counts{word}{err_any}{$b} <=> $counts{word}{err_any}{$a}} keys %{$counts{word}{err_any}})
1323
+ {
1324
+ if (!defined($counts{word}{err_head}{$word}))
1325
+ {
1326
+ $counts{word}{err_head}{$word} = 0 ;
1327
+ }
1328
+ if (! defined($counts{word}{err_dep}{$word}))
1329
+ {
1330
+ $counts{word}{err_dep}{$word} = 0 ;
1331
+ }
1332
+ if (! defined($counts{word}{err_any}{$word}))
1333
+ {
1334
+ $counts{word}{err_any}{$word} = 0;
1335
+ }
1336
+ printf OUT " %-*s | %4d | %4d | %4d | %4d\n",
1337
+ $max_word_len+length($word)-uni_len($word), $word, $counts{word}{err_any}{$word},
1338
+ $counts{word}{err_head}{$word},
1339
+ $counts{word}{err_dep}{$word},
1340
+ $counts{word}{err_dep}{$word}+$counts{word}{err_head}{$word}-$counts{word}{err_any}{$word} ;
1341
+ }
1342
+
1343
+ printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
1344
+
1345
+ # ------------- contexts
1346
+
1347
+ printf OUT "\n\n" ;
1348
+
1349
+ printf OUT " one-token preceding contexts where most of the errors occur:\n\n" ;
1350
+
1351
+ print_context($counts{con_bef}, $counts{con_pos_bef}, $max_con_len, $max_con_pos_len) ;
1352
+
1353
+ printf OUT " two-token preceding contexts where most of the errors occur:\n\n" ;
1354
+
1355
+ print_context($counts{con_bef_2}, $counts{con_pos_bef_2}, $max_con_len, $max_con_pos_len) ;
1356
+
1357
+ printf OUT " one-token following contexts where most of the errors occur:\n\n" ;
1358
+
1359
+ print_context($counts{con_aft}, $counts{con_pos_aft}, $max_con_len, $max_con_pos_len) ;
1360
+
1361
+ printf OUT " two-token following contexts where most of the errors occur:\n\n" ;
1362
+
1363
+ print_context($counts{con_aft_2}, $counts{con_pos_aft_2}, $max_con_len, $max_con_pos_len) ;
1364
+
1365
+ # ------------- Sentences
1366
+
1367
+ printf OUT " Sentence with the highest number of word errors:\n" ;
1368
+ $i = (sort { (defined($err_sent[$b]{word}) && $err_sent[$b]{word})
1369
+ <=> (defined($err_sent[$a]{word}) && $err_sent[$a]{word}) } 1 .. $sent_num)[0] ;
1370
+ printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
1371
+ printf OUT "%d head errors, %d dependency errors, %d word errors\n",
1372
+ $err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
1373
+
1374
+ printf OUT "\n\n" ;
1375
+
1376
+ printf OUT " Sentence with the highest number of head errors:\n" ;
1377
+ $i = (sort { (defined($err_sent[$b]{head}) && $err_sent[$b]{head})
1378
+ <=> (defined($err_sent[$a]{head}) && $err_sent[$a]{head}) } 1 .. $sent_num)[0] ;
1379
+ printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
1380
+ printf OUT "%d head errors, %d dependency errors, %d word errors\n",
1381
+ $err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
1382
+
1383
+ printf OUT "\n\n" ;
1384
+
1385
+ printf OUT " Sentence with the highest number of dependency errors:\n" ;
1386
+ $i = (sort { (defined($err_sent[$b]{dep}) && $err_sent[$b]{dep})
1387
+ <=> (defined($err_sent[$a]{dep}) && $err_sent[$a]{dep}) } 1 .. $sent_num)[0] ;
1388
+ printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
1389
+ printf OUT "%d head errors, %d dependency errors, %d word errors\n",
1390
+ $err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
1391
+
1392
+ #
1393
+ # Second pass, collect statistics of the frequent errors
1394
+ #
1395
+
1396
+ # filter the errors, leave the most frequent $freq_err_num errors
1397
+
1398
+ $i = 0 ;
1399
+
1400
+ $thresh = (sort {$b <=> $a} values %freq_err)[$freq_err_num-1] ;
1401
+
1402
+ foreach $err (keys %freq_err)
1403
+ {
1404
+ if ($freq_err{$err} < $thresh)
1405
+ {
1406
+ delete $freq_err{$err} ;
1407
+ }
1408
+ }
1409
+
1410
+ # in case there are several errors with the threshold count
1411
+
1412
+ $freq_err_num = scalar keys %freq_err ;
1413
+
1414
+ %err_counts = () ;
1415
+
1416
+ $eof = 0 ;
1417
+
1418
+ seek (GOLD, 0, 0) ;
1419
+ seek (SYS, 0, 0) ;
1420
+
1421
+ while (! $eof)
1422
+ { # second reading loop
1423
+
1424
+ $eof = read_sent(\@sent_gold, \@sent_sys) ;
1425
+ $sent_num++ ;
1426
+
1427
+ $word_num = scalar @sent_gold ;
1428
+
1429
+ # printf "$sent_num $word_num\n" ;
1430
+
1431
+ foreach $i_w (0 .. $word_num-1)
1432
+ { # loop on words
1433
+ ($word, $pos, $head_g, $dep_g)
1434
+ = @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
1435
+
1436
+ # printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
1437
+
1438
+ if ((! $score_on_punct) && is_uni_punct($word))
1439
+ {
1440
+ # ignore punctuations
1441
+ next ;
1442
+ }
1443
+
1444
+ ($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
1445
+
1446
+ $err_head = ($head_g ne $head_s) ;
1447
+ $err_dep = ($dep_g ne $dep_s) ;
1448
+
1449
+ $head_err = '-' ;
1450
+ $dep_err = '-' ;
1451
+
1452
+ if ($head_g eq '0')
1453
+ {
1454
+ $head_aft_bef_g = '0' ;
1455
+ }
1456
+ elsif ($head_g eq $i_w+1)
1457
+ {
1458
+ $head_aft_bef_g = 'e' ;
1459
+ }
1460
+ else
1461
+ {
1462
+ $head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
1463
+ }
1464
+
1465
+ if ($head_s eq '0')
1466
+ {
1467
+ $head_aft_bef_s = '0' ;
1468
+ }
1469
+ elsif ($head_s eq $i_w+1)
1470
+ {
1471
+ $head_aft_bef_s = 'e' ;
1472
+ }
1473
+ else
1474
+ {
1475
+ $head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
1476
+ }
1477
+
1478
+ $head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
1479
+
1480
+ if ($err_head)
1481
+ {
1482
+ if ($head_aft_bef_s eq '0')
1483
+ {
1484
+ $head_err = 0 ;
1485
+ }
1486
+ else
1487
+ {
1488
+ $head_err = $head_s-$head_g ;
1489
+ }
1490
+ }
1491
+
1492
+ if ($err_dep)
1493
+ {
1494
+ $dep_err = $dep_g.'->'.$dep_s ;
1495
+ }
1496
+
1497
+ if (! ($err_head || $err_dep))
1498
+ {
1499
+ next ;
1500
+ }
1501
+
1502
+ # handle only the most frequent errors
1503
+
1504
+ $err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
1505
+
1506
+ if (! exists $freq_err{$err})
1507
+ {
1508
+ next ;
1509
+ }
1510
+
1511
+ ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
1512
+
1513
+ $con_bef = $w_1 ;
1514
+ $con_bef_2 = $w_2.' + '.$w_1 ;
1515
+ $con_aft = $w1 ;
1516
+ $con_aft_2 = $w1.' + '.$w2 ;
1517
+
1518
+ $con_pos_bef = $p_1 ;
1519
+ $con_pos_bef_2 = $p_2.'+'.$p_1 ;
1520
+ $con_pos_aft = $p1 ;
1521
+ $con_pos_aft_2 = $p1.'+'.$p2 ;
1522
+
1523
+ @cur_err = ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) ;
1524
+
1525
+ # printf "# %-25s %-15s %-10s %-25s %-3s %-30s\n",
1526
+ # $con_bef, $word, $pos, $con_aft, $head_err, $dep_err ;
1527
+
1528
+ @bits = (0, 0, 0, 0, 0, 0) ;
1529
+ $j = 0 ;
1530
+
1531
+ while ($j == 0)
1532
+ {
1533
+ for ($i = 0; $i <= $#bits; $i++)
1534
+ {
1535
+ if ($bits[$i] == 0)
1536
+ {
1537
+ $bits[$i] = 1 ;
1538
+ $j = 0 ;
1539
+ last ;
1540
+ }
1541
+ else
1542
+ {
1543
+ $bits[$i] = 0 ;
1544
+ $j = 1 ;
1545
+ }
1546
+ }
1547
+
1548
+ @e_bits = @cur_err ;
1549
+
1550
+ for ($i = 0; $i <= $#bits; $i++)
1551
+ {
1552
+ if (! $bits[$i])
1553
+ {
1554
+ $e_bits[$i] = '*' ;
1555
+ }
1556
+ }
1557
+
1558
+ # include also the last case which is the most general
1559
+ # (wildcards for everything)
1560
+ $err_counts{$err}{join($sep, @e_bits)}++ ;
1561
+
1562
+ }
1563
+
1564
+ } # loop on words
1565
+ } # second reading loop
1566
+
1567
+ printf OUT "\n\n" ;
1568
+ printf OUT " Specific errors, %d most frequent errors:", $freq_err_num ;
1569
+ printf OUT "\n %s\n", '=' x 41 ;
1570
+
1571
+
1572
+ # deleting local contexts which are too general
1573
+
1574
+ foreach $err (keys %err_counts)
1575
+ {
1576
+ foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
1577
+ keys %{$err_counts{$err}})
1578
+ {
1579
+ @cur_err = split(/\Q$sep\E/, $loc_con) ;
1580
+
1581
+ # In this loop, one or two elements of the local context are
1582
+ # replaced with '*' to make it more general. If the entry for
1583
+ # the general context has the same count it is removed.
1584
+
1585
+ foreach $i (0 .. $#cur_err)
1586
+ {
1587
+ $w1 = $cur_err[$i] ;
1588
+ if ($cur_err[$i] eq '*')
1589
+ {
1590
+ next ;
1591
+ }
1592
+ $cur_err[$i] = '*' ;
1593
+ $con1 = join($sep, @cur_err) ;
1594
+ if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
1595
+ && ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
1596
+ {
1597
+ delete $err_counts{$err}{$con1} ;
1598
+ }
1599
+ for ($j = $i+1; $j <=$#cur_err; $j++)
1600
+ {
1601
+ if ($cur_err[$j] eq '*')
1602
+ {
1603
+ next ;
1604
+ }
1605
+ $w2 = $cur_err[$j] ;
1606
+ $cur_err[$j] = '*' ;
1607
+ $con1 = join($sep, @cur_err) ;
1608
+ if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
1609
+ && ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
1610
+ {
1611
+ delete $err_counts{$err}{$con1} ;
1612
+ }
1613
+ $cur_err[$j] = $w2 ;
1614
+ }
1615
+ $cur_err[$i] = $w1 ;
1616
+ }
1617
+ }
1618
+ }
1619
+
1620
+ # Leaving only the topmost local contexts for each error
1621
+
1622
+ foreach $err (keys %err_counts)
1623
+ {
1624
+ $thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[$spec_err_loc_con-1] || 0 ;
1625
+
1626
+ # if the threshold is too low, take the 2nd highest count
1627
+ # (the highest may be the total which is the generic case
1628
+ # and not relevant for printing)
1629
+
1630
+ if ($thresh < 5)
1631
+ {
1632
+ $thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[1] ;
1633
+ }
1634
+
1635
+ foreach $loc_con (keys %{$err_counts{$err}})
1636
+ {
1637
+ if ($err_counts{$err}{$loc_con} < $thresh)
1638
+ {
1639
+ delete $err_counts{$err}{$loc_con} ;
1640
+ }
1641
+ else
1642
+ {
1643
+ if ($loc_con ne join($sep, ('*', '*', '*', '*', '*', '*')))
1644
+ {
1645
+ $loc_con_err_counts{$loc_con}{$err} = $err_counts{$err}{$loc_con} ;
1646
+ }
1647
+ }
1648
+ }
1649
+ }
1650
+
1651
+ # printing an error summary
1652
+
1653
+ # calculating the context field length
1654
+
1655
+ $max_word_spec_len= length('word') ;
1656
+ $max_con_aft_len = length('word') ;
1657
+ $max_con_bef_len = length('word') ;
1658
+ $max_con_pos_len = length('CPOS') ;
1659
+
1660
+ foreach $err (keys %err_counts)
1661
+ {
1662
+ foreach $loc_con (sort keys %{$err_counts{$err}})
1663
+ {
1664
+ ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
1665
+ split(/\Q$sep\E/, $loc_con) ;
1666
+
1667
+ $l = uni_len($word) ;
1668
+ if ($l > $max_word_spec_len)
1669
+ {
1670
+ $max_word_spec_len = $l ;
1671
+ }
1672
+
1673
+ $l = uni_len($con_bef) ;
1674
+ if ($l > $max_con_bef_len)
1675
+ {
1676
+ $max_con_bef_len = $l ;
1677
+ }
1678
+
1679
+ $l = uni_len($con_aft) ;
1680
+ if ($l > $max_con_aft_len)
1681
+ {
1682
+ $max_con_aft_len = $l ;
1683
+ }
1684
+
1685
+ if (length($con_pos_aft) > $max_con_pos_len)
1686
+ {
1687
+ $max_con_pos_len = length($con_pos_aft) ;
1688
+ }
1689
+
1690
+ if (length($con_pos_bef) > $max_con_pos_len)
1691
+ {
1692
+ $max_con_pos_len = length($con_pos_bef) ;
1693
+ }
1694
+ }
1695
+ }
1696
+
1697
+ $err_counter = 0 ;
1698
+
1699
+ foreach $err (sort {$freq_err{$b} <=> $freq_err{$a}} keys %freq_err)
1700
+ {
1701
+
1702
+ ($head_err, $head_aft_bef, $dep_err) = split(/\Q$sep\E/, $err) ;
1703
+
1704
+ $err_counter++ ;
1705
+ $err_desc{$err} = sprintf("%2d. ", $err_counter).
1706
+ describe_err($head_err, $head_aft_bef, $dep_err) ;
1707
+
1708
+ # printf OUT " %-3s %-30s %d\n", $head_err, $dep_err, $freq_err{$err} ;
1709
+ printf OUT "\n" ;
1710
+ printf OUT " %s : %d times\n", $err_desc{$err}, $freq_err{$err} ;
1711
+
1712
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
1713
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1714
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1715
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1716
+
1717
+ printf OUT " %-*s | %-*s | %-*s | %s\n",
1718
+ $max_con_pos_len+$max_con_bef_len+3, ' Before',
1719
+ $max_word_spec_len+$max_pos_len+3, ' Focus',
1720
+ $max_con_pos_len+$max_con_aft_len+3, ' After',
1721
+ 'Count' ;
1722
+
1723
+ printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s |\n",
1724
+ $max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
1725
+ $max_pos_len, 'CPOS', $max_word_spec_len, 'word',
1726
+ $max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
1727
+
1728
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
1729
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1730
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1731
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1732
+
1733
+ foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
1734
+ keys %{$err_counts{$err}})
1735
+ {
1736
+ if ($loc_con eq join($sep, ('*', '*', '*', '*', '*', '*')))
1737
+ {
1738
+ next ;
1739
+ }
1740
+
1741
+ $con1 = $loc_con ;
1742
+ $con1 =~ s/\*/ /g ;
1743
+
1744
+ ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
1745
+ split(/\Q$sep\E/, $con1) ;
1746
+
1747
+ printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %3d\n",
1748
+ $max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
1749
+ $max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
1750
+ $max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft,
1751
+ $err_counts{$err}{$loc_con} ;
1752
+ }
1753
+
1754
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
1755
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1756
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1757
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1758
+
1759
+ }
1760
+
1761
+ printf OUT "\n\n" ;
1762
+ printf OUT " Local contexts involved in several frequent errors:" ;
1763
+ printf OUT "\n %s\n", '=' x 51 ;
1764
+ printf OUT "\n\n" ;
1765
+
1766
+ foreach $loc_con (sort {scalar keys %{$loc_con_err_counts{$b}} <=>
1767
+ scalar keys %{$loc_con_err_counts{$a}}}
1768
+ keys %loc_con_err_counts)
1769
+ {
1770
+
1771
+ if (scalar keys %{$loc_con_err_counts{$loc_con}} == 1)
1772
+ {
1773
+ next ;
1774
+ }
1775
+
1776
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
1777
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1778
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1779
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1780
+
1781
+ printf OUT " %-*s | %-*s | %-*s \n",
1782
+ $max_con_pos_len+$max_con_bef_len+3, ' Before',
1783
+ $max_word_spec_len+$max_pos_len+3, ' Focus',
1784
+ $max_con_pos_len+$max_con_aft_len+3, ' After' ;
1785
+
1786
+ printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s \n",
1787
+ $max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
1788
+ $max_pos_len, 'CPOS', $max_word_spec_len, 'word',
1789
+ $max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
1790
+
1791
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
1792
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1793
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1794
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1795
+
1796
+ $con1 = $loc_con ;
1797
+ $con1 =~ s/\*/ /g ;
1798
+
1799
+ ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
1800
+ split(/\Q$sep\E/, $con1) ;
1801
+
1802
+ printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s \n",
1803
+ $max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
1804
+ $max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
1805
+ $max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft ;
1806
+
1807
+ printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
1808
+ '-' x $max_con_pos_len, '-' x $max_con_bef_len,
1809
+ '-' x $max_pos_len, '-' x $max_word_spec_len,
1810
+ '-' x $max_con_pos_len, '-' x $max_con_aft_len ;
1811
+
1812
+ foreach $err (sort {$loc_con_err_counts{$loc_con}{$b} <=>
1813
+ $loc_con_err_counts{$loc_con}{$a}}
1814
+ keys %{$loc_con_err_counts{$loc_con}})
1815
+ {
1816
+ printf OUT " %s : %d times\n", $err_desc{$err},
1817
+ $loc_con_err_counts{$loc_con}{$err} ;
1818
+ }
1819
+
1820
+ printf OUT "\n" ;
1821
+ }
1822
+
1823
+ close GOLD ;
1824
+ close SYS ;
1825
+
1826
+ close OUT ;
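
The summary lines of this evaluation script reduce to three ratios over the scoring (non-punctuation) tokens: LAS = (tot - err_any) / tot, UAS = (tot - err_head) / tot, and label accuracy = (tot - err_dep) / tot. Below is a minimal Python sketch of the same arithmetic, assuming gold and system parses are given as (head, deprel) pairs per scoring token; it is illustrative only and not part of this commit.

def attachment_scores(gold, system):
    """gold, system: lists of (head, deprel) pairs for the scoring tokens."""
    tot = len(gold)
    err_head = sum(g[0] != s[0] for g, s in zip(gold, system))
    err_dep = sum(g[1] != s[1] for g, s in zip(gold, system))
    err_any = sum(g[0] != s[0] or g[1] != s[1] for g, s in zip(gold, system))
    las = 100.0 * (tot - err_any) / tot   # Labeled attachment score
    uas = 100.0 * (tot - err_head) / tot  # Unlabeled attachment score
    la = 100.0 * (tot - err_dep) / tot    # Label accuracy score
    return las, uas, la

print(attachment_scores([(2, 'nsubj'), (0, 'root')],
                        [(2, 'nsubj'), (1, 'root')]))  # (50.0, 50.0, 100.0)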
examples/test_original_dcst.sh ADDED
@@ -0,0 +1,110 @@
1
+
2
+ domain="san"
3
+ word_path="../Documents/DCST/data/multilingual_word_embeddings/cc.sanskrit.300.new.vec"
4
+ char_path="../Documents/DCST/data/multilingual_word_embeddings/hellwig_char_embedding.128"
5
+ pos_path='../Documents/DCST/data/multilingual_word_embeddings/pos_embedding_FT.100'
6
+ declare -i num_epochs=1
7
+ declare -i word_dim=300
8
+ declare -i set_num_training_samples=500
9
+ start_time=`date +%s`
10
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
11
+ # model_path="ud_parser_san_"$current_time
12
+
13
+ # ######## Changes ###########
14
+ # ## --use_pos is removed from all the models
15
+ # ## singletons are excluded.
16
+ # ## auxiliary task epochs changed to 20
17
+
18
+ touch saved_models/log.txt
19
+ ################################################################
20
+ echo "#################################################################"
21
+ echo "Currently base model in progress..."
22
+ echo "#################################################################"
23
+ python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
24
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 --arc_space 512 \
25
+ --arc_tag_space 128 --num_layers 3 --num_filters 100 --use_char \
26
+ --word_dim $word_dim --pos_dim 100 --initializer xavier --opt adam \
27
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 \
28
+ --epsilon 1e-6 --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
29
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext \
30
+ --set_num_training_samples $set_num_training_samples \
31
+ --word_path $word_path --char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
32
+ --model_path saved_models/$model_path 2>&1 | tee saved_models/log.txt
33
+ mv saved_models/log.txt saved_models/$model_path/log.txt
34
+
35
+
36
+ ## model_path="ud_parser_san_2020.03.03-12.24.41"
37
+
38
+ ####### self-training ###########
39
+ ## 'number_of_children' 'relative_pos_based' 'distance_from_the_root'
40
+
41
+ ####### Multitask setup #########
42
+ ## 'Multitask_label_predict' 'Multitask_case_predict' 'Multitask_POS_predict' 'Multitask_coarse_predict'
43
+ ## 'predict_ma_tag_of_modifier' 'add_label' 'predict_case_of_modifier'
44
+
45
+ ####### Other Tasks #############
46
+ ## # 'add_head_ma' 'add_head_coarse_pos' 'predict_coarse_of_modifier'
47
+
48
+ for task in 'Multitask_label_predict' 'Multitask_case_predict' 'Multitask_POS_predict' 'number_of_children' 'relative_pos_based' 'distance_from_the_root' ; do
49
+ echo "#################################################################"
50
+ echo "Currently $task in progress..."
51
+ echo "#################################################################"
52
+ touch saved_models/log.txt
53
+ python examples/SequenceTagger.py --dataset ud --domain $domain --task $task \
54
+ --rnn_mode LSTM --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
55
+ --tag_space 128 --num_layers 3 --num_filters 100 --use_char \
56
+ --pos_dim 100 --initializer xavier --opt adam --learning_rate 0.002 --decay_rate 0.5 \
57
+ --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 --p_rnn 0.33 0.33 \
58
+ --p_in 0.33 --p_out 0.33 --punct_set '.' '``' ':' ',' \
59
+ --word_dim $word_dim --word_embedding fasttext --word_path $word_path \
60
+ --parser_path saved_models/$model_path/ \
61
+ --use_unlabeled_data --char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
62
+ --model_path saved_models/$model_path/$task/ 2>&1 | tee saved_models/log.txt
63
+ mv saved_models/log.txt saved_models/$model_path/$task/log.txt
64
+ done
65
+
66
+ echo "#################################################################"
67
+ echo "Currently final model (multitask gates) in progress..."
68
+ echo "#################################################################"
69
+ touch saved_models/log.txt
70
+ python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
71
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
72
+ --arc_space 512 --arc_tag_space 128 --num_layers 3 --num_filters 100 --use_char \
73
+ --word_dim $word_dim --pos_dim 100 --initializer xavier --opt adam \
74
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
75
+ --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
76
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext \
77
+ --word_path $word_path --gating --num_gates 4 \
78
+ --char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
79
+ --load_sequence_taggers_paths saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
80
+ saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
81
+ saved_models/$model_path/Multitask_label_predict/domain_$domain.pt \
82
+ --model_path saved_models/$model_path/final_ensembled_multi/ 2>&1 | tee saved_models/log.txt
83
+ mv saved_models/log.txt saved_models/$model_path/final_ensembled_multi/log.txt
84
+ end_time=`date +%s`
85
+ echo execution time was `expr $end_time - $start_time` s.
86
+
87
+
88
+ echo "#################################################################"
89
+ echo "Currently final model (multitask + self-training gates) in progress..."
90
+ echo "#################################################################"
91
+ touch saved_models/log.txt
92
+ python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
93
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
94
+ --arc_space 512 --arc_tag_space 128 --num_layers 3 --num_filters 100 --use_char \
95
+ --word_dim $word_dim --pos_dim 100 --initializer xavier --opt adam \
96
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
97
+ --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
98
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext \
99
+ --word_path $word_path --gating --num_gates 7 \
100
+ --char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
101
+ --load_sequence_taggers_paths saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
102
+ saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
103
+ saved_models/$model_path/Multitask_label_predict/domain_$domain.pt \
104
+ saved_models/$model_path/number_of_children/domain_$domain.pt \
105
+ saved_models/$model_path/relative_pos_based/domain_$domain.pt \
106
+ saved_models/$model_path/distance_from_the_root/domain_$domain.pt \
107
+ --model_path saved_models/$model_path/final_ensembled_multi_self/ 2>&1 | tee saved_models/log.txt
108
+ mv saved_models/log.txt saved_models/$model_path/final_ensembled_multi_self/log.txt
109
+ end_time=`date +%s`
110
+ echo execution time was `expr $end_time - $start_time` s.
run_san_LCM.sh ADDED
@@ -0,0 +1,73 @@
1
+
2
+ domain="san"
3
+
4
+ word_path="./data/multilingual_word_embeddings/cc.sanskrit.300.new.vec"
5
+ declare -i num_epochs=100
6
+ declare -i word_dim=300
7
+ declare -i set_num_training_samples=500
8
+ start_time=`date +%s`
9
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
10
+ model_path="ud_parser_san_"$current_time
11
+
12
+
13
+ echo "#################################################################"
14
+ echo "Currently BiAFFINE model training in progress..."
15
+ echo "#################################################################"
16
+ python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
17
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 --arc_space 512 \
18
+ --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char \
19
+ --set_num_training_samples $set_num_training_samples \
20
+ --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
21
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 \
22
+ --epsilon 1e-6 --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
23
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --pos_embedding random --word_path $word_path \
24
+ --model_path saved_models/$model_path 2>&1 | tee saved_models/base_log.txt
25
+
26
+ mv saved_models/base_log.txt saved_models/$model_path/base_log.txt
27
+
28
+ ##################################################################
29
+ ## Pretraining Step
30
+ ## For the DCST setting, set tasks to: 'number_of_children' 'relative_pos_based' 'distance_from_the_root'
31
+ ## For the LCM setting, set tasks to: 'Multitask_case_predict' 'Multitask_POS_predict' 'add_label'
32
+ for task in 'Multitask_case_predict' 'Multitask_POS_predict' 'add_label'; do
33
+ touch saved_models/$model_path/log.txt
34
+ echo "#################################################################"
35
+ echo "Currently $task in progress..."
36
+ echo "#################################################################"
37
+ python examples/SequenceTagger.py --dataset ud --domain $domain --task $task \
38
+ --rnn_mode LSTM --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
39
+ --tag_space 128 --num_layers 2 --num_filters 100 --use_char --char_dim 100 \
40
+ --pos_dim 100 --initializer xavier --opt adam --learning_rate 0.002 --decay_rate 0.5 \
41
+ --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 --p_rnn 0.33 0.33 \
42
+ --p_in 0.33 --p_out 0.33 --punct_set '.' '``' ':' ',' \
43
+ --word_dim $word_dim --word_embedding fasttext --word_path $word_path --pos_embedding random \
44
+ --parser_path saved_models/$model_path/ \
45
+ --use_unlabeled_data --use_labeled_data --char_embedding random \
46
+ --model_path saved_models/$model_path/$task/ 2>&1 | tee saved_models/$model_path/log.txt
47
+ mv saved_models/$model_path/log.txt saved_models/$model_path/$task/log.txt
48
+ done
49
+
50
+ # ########################################################################
51
+
52
+ echo "#################################################################"
53
+ echo "Final Parsing Model Integrated with Pretrained Encoders of Auxiliary tasks..."
54
+ echo "#################################################################"
55
+ touch saved_models/$model_path/log.txt
56
+ python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
57
+ --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
58
+ --arc_space 512 --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char \
59
+ --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
60
+ --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
61
+ --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst --pos_embedding random \
62
+ --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --word_path $word_path \
63
+ --gating --num_gates 4 \
64
+ --set_num_training_samples $set_num_training_samples \
65
+ --load_sequence_taggers_paths saved_models/$model_path/add_label/domain_$domain.pt \
66
+ saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
67
+ saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
68
+ --model_path saved_models/$model_path/final_ensembled_BiAFF_LCM 2>&1 | tee saved_models/$model_path/log.txt
69
+ mv saved_models/$model_path/log.txt saved_models/$model_path/final_ensembled_BiAFF_LCM/log.txt
70
+
71
+ end_time=`date +%s`
72
+ echo execution time was `expr $end_time - $start_time` s.
73
+
utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from . import io_
2
+ from . import nn
3
+ from . import load_word_embeddings
4
+ from . import nlinalg
5
+ from . import models
6
+
7
+ __version__ = "0.1.dev1"
utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (290 Bytes). View file
 
utils/__pycache__/load_word_embeddings.cpython-37.pyc ADDED
Binary file (2.63 kB). View file
 
utils/io_/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .alphabet import *
2
+ from .instance import *
3
+ from .logger import *
4
+ from .writer import *
5
+ from . import prepare_data
utils/io_/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (228 Bytes). View file
 
utils/io_/__pycache__/alphabet.cpython-37.pyc ADDED
Binary file (5.35 kB). View file
 
utils/io_/__pycache__/instance.cpython-37.pyc ADDED
Binary file (1.09 kB). View file
 
utils/io_/__pycache__/logger.cpython-37.pyc ADDED
Binary file (545 Bytes). View file
 
utils/io_/__pycache__/prepare_data.cpython-37.pyc ADDED
Binary file (10.2 kB). View file
 
utils/io_/__pycache__/reader.cpython-37.pyc ADDED
Binary file (2.37 kB). View file
 
utils/io_/__pycache__/rearrange_splits.cpython-37.pyc ADDED
Binary file (2.79 kB). View file
 
utils/io_/__pycache__/seeds.cpython-37.pyc ADDED
Binary file (358 Bytes). View file
 
utils/io_/__pycache__/write_extra_labels.cpython-37.pyc ADDED
Binary file (36.3 kB). View file
 
utils/io_/__pycache__/writer.cpython-37.pyc ADDED
Binary file (2.13 kB). View file
 
utils/io_/alphabet.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ Alphabet maps objects to integer ids. It provides a two-way mapping between indices and objects.
3
+ """
4
+ import json
5
+ import os
6
+ from .logger import get_logger
7
+
8
+ class Alphabet(object):
9
+ def __init__(self, name, defualt_value=False, keep_growing=True, singleton=False):
10
+ self.__name = name
11
+
12
+ self.instance2index = {}
13
+ self.instances = []
14
+ self.default_value = defualt_value
15
+ self.offset = 1 if self.default_value else 0
16
+ self.keep_growing = keep_growing
17
+ self.singletons = set() if singleton else None
18
+
19
+ # Index 0 is occupied by default, all else following.
20
+ self.default_index = 0 if self.default_value else None
21
+
22
+ self.next_index = self.offset
23
+
24
+ self.logger = get_logger("Alphabet")
25
+
26
+ def add(self, instance):
27
+ if instance not in self.instance2index and instance != '<_UNK>':
28
+ self.instances.append(instance)
29
+ self.instance2index[instance] = self.next_index
30
+ self.next_index += 1
31
+
32
+ def add_singleton(self, id):
33
+ if self.singletons is None:
34
+ raise RuntimeError("Alphabet %s does not have singleton." % self.__name)
35
+ else:
36
+ self.singletons.add(id)
37
+
38
+ def add_singletons(self, ids):
39
+ if self.singletons is None:
40
+ raise RuntimeError("Alphabet %s does not have singleton." % self.__name)
41
+ else:
42
+ self.singletons.update(ids)
43
+
44
+ def is_singleton(self, id):
45
+ if self.singletons is None:
46
+ raise RuntimeError("Alphabet %s does not have singleton." % self.__name)
47
+ else:
48
+ return id in self.singletons
49
+
50
+ def get_index(self, instance):
51
+ try:
52
+ return self.instance2index[instance]
53
+ except KeyError:
54
+ if self.keep_growing:
55
+ index = self.next_index
56
+ self.add(instance)
57
+ return index
58
+ else:
59
+ if self.default_value:
60
+ return self.default_index
61
+ else:
62
+ raise KeyError("instance not found: %s" % instance)
63
+
64
+ def get_instance(self, index):
65
+ if self.default_value and index == self.default_index:
66
+ # First index is occupied by the wildcard element.
67
+ return "<_UNK>"
68
+ else:
69
+ try:
70
+ return self.instances[index - self.offset]
71
+ except IndexError:
72
+ raise IndexError("unknown index: %d" % index)
73
+
74
+ def size(self):
75
+ return len(self.instances) + self.offset
76
+
77
+ def singleton_size(self):
78
+ return len(self.singletons)
79
+
80
+ def items(self):
81
+ return self.instance2index.items()
82
+
83
+ def keys(self):
84
+ return self.instance2index.keys()
85
+
86
+ def values(self):
87
+ return self.instance2index.values()
88
+
89
+ def token_in_alphabet(self, token):
90
+ return token in set(self.instance2index.keys())
91
+
92
+ def enumerate_items(self, start):
93
+ if start < self.offset or start >= self.size():
94
+ raise IndexError("Enumerate is allowed between [%d : size of the alphabet)" % self.offset)
95
+ return zip(range(start, len(self.instances) + self.offset), self.instances[start - self.offset:])
96
+
97
+ def close(self):
98
+ self.keep_growing = False
99
+
100
+ def open(self):
101
+ self.keep_growing = True
102
+
103
+ def get_content(self):
104
+ if self.singletons is None:
105
+ return {"instance2index": self.instance2index, "instances": self.instances}
106
+ else:
107
+ return {"instance2index": self.instance2index, "instances": self.instances,
108
+ "singletions": list(self.singletons)}
109
+
110
+ def __from_json(self, data):
111
+ self.instances = data["instances"]
112
+ self.instance2index = data["instance2index"]
113
+ if "singletions" in data:
114
+ self.singletons = set(data["singletions"])
115
+ else:
116
+ self.singletons = None
117
+
118
+ def save(self, output_directory, name=None):
119
+ """
120
+ Save both alphabet records to the given directory.
121
+ :param output_directory: Directory to save model and weights.
122
+ :param name: The alphabet saving name, optional.
123
+ :return:
124
+ """
125
+ saving_name = name if name else self.__name
126
+ try:
127
+ if not os.path.exists(output_directory):
128
+ os.makedirs(output_directory)
129
+
130
+ json.dump(self.get_content(),
131
+ open(os.path.join(output_directory, saving_name + ".json"), 'w'), indent=4)
132
+ except Exception as e:
133
+ self.logger.warn("Alphabet is not saved: %s" % repr(e))
134
+
135
+ def load(self, input_directory, name=None):
136
+ """
137
+ Load model architecture and weights from the given directory. This allows old saved_models to be used even if the
139
+ structure changes.
139
+ :param input_directory: Directory to save model and weights
140
+ :return:
141
+ """
142
+ loading_name = name if name else self.__name
143
+ filename = os.path.join(input_directory, loading_name + ".json")
144
+ f = json.load(open(filename))
145
+ self.__from_json(f)
146
+ self.next_index = len(self.instances) + self.offset
147
+ self.keep_growing = False
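
A minimal usage sketch for the Alphabet class above. The alphabet name and tokens are illustrative only; note that the constructor keyword is spelled defualt_value in the source, so callers must use that spelling.

word_alphabet = Alphabet('word', defualt_value=True, singleton=True)
for token in ['ramah', 'vanam', 'gacchati', 'ramah']:
    word_alphabet.add(token)                  # duplicates are ignored
idx = word_alphabet.get_index('vanam')        # 2: real indices start at offset 1 because defualt_value=True
print(word_alphabet.get_instance(idx))        # 'vanam' (round trip index -> instance)
word_alphabet.close()                         # stop growing
print(word_alphabet.get_index('unseen'))      # 0: falls back to the default (<_UNK>) index
word_alphabet.save('saved_models/alphabets')  # writes saved_models/alphabets/word.json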
utils/io_/coarse_to_ma_dict.json ADDED
@@ -0,0 +1 @@
1
+ {"Noun": ["acc. pl. m.", "acc. pl. *", "inst. sg. m.", "abl. du. f.", "nom. sg. *.", "loc. pl. n.", "loc. du. f.", "nom. sg. m", "g. sg. *.", "i. sg. m", "g. pl. *", "nom. pl. n.", "g. du. f.", "acc. du. f.", "loc. du. m.", "nom. sg. *", "nom. sg. m.", "loc. pl. m.", "inst. pl. n.", "g. du. n.", "nom. du. f.", "inst. pl. f.", "acc. sg. n", "abl.sg.m", "loc. pl. *", "abl. pl. n.", "i. pl. n.", "dat. du. m.", "abl. pl. m.", "inst. sg. n.", "acc. pl. n.", "nom. du. *.", "inst. du. n.", "nom. pl. *", "loc. du. *.", "abl. du. n.", "nom. pl. *.", "i. sg. m.", "dat. sg. f.", "inst. sg. *.", "abl. sg. m.", "acc. sg. *", "dat. sg. *", "dat. pl. *", "dat. du. n.", "i. pl. m.", "inst. du. m.", "nom. sg. f.", "g. sg. *", "acc. sg. *.", "i. sg. n.", "i. pl. *", "loc. du. n.", "abl. sg. *", "loc. sg. m.", "nom.sg.m", "i. sg. n..", "abl. sg. f.", "nom. sg. sg.", "dat. du. f.", "nom. sg. n.", "loc. sg. f.", "acc. du. *", "acc. pl. *.", "dat. sg. m.", "nom. sg. n", "nom. pl. f.", "i. pl. f.", "acc. du. m.", "inst. sg. f.", "g. sg. n.", "abl. sg. n", "inst. pl. m.", "loc. pl. f.", "acc. sg. f.", "g. sg. f.", "voc. du. m.", "abl. sg. n.", "voc. sg. *", "dat. sg. n.", "i. sg. *", "acc. du. n.", "g. sg. m.", "loc. sg. n", "dat. pl. m.", "nom. du. *", "nom. du. n.", "acc. sg. m.", "nom. pl. m.", "acc. pl. f.", "acc. sg. f", "g. pl. m.", "voc. du. f.", "g. pl. *.", "voc. du. n.", "dat. pl. f.", "nom. du. m.", "abl. pl. f.", "dat. pl. n.", "dat. sg. m", "loc. sg. *.", "g. pl. n.", "loc. sg. n.", "abl. sg. m", "g. du. m.", "loc. sg. *", "voc. sg. m.", "voc. pl. m.", "voc. sg. n.", "acc. sg. n.", "g. pl. f.", "i. sg. f.", "abl. du. m.", "voc. sg. f.", "voc. pl. f."], "FV": ["impft. [2] ac. du. 3", "imp. [1] ac. du. 1", "imp. [3] ac. pl. 2", "ca. pr. md. sg. 1", "ca. impft. ac. sg. 1", "pp. . pl. ", "imp. [4] ac. sg. 2", "impft . ps. pl. 3", "pr. ac. sg. 1", "pp. . sg. ", "pr. [6] ac. pl. 3", "imp. [1] ac. pl. 2", "cond. ac. sg. 3", "pft. ac. sg. 2", "impft. [6] md. du. 2", "pr. [7] md. sg. 3", "pr. ps. sg. 2", "impft. [1] md. pl. 3", "ca. impft. ac. pl. 1", "pr. [10] ac. sg. 3", "pr. [5] ac. pl. 3", "imp. md. sg. 3", "fut. ac. du. 1", "imp. [10] ac. pl. 2", "impft. [1] md. sg. 3", "impft. [6] ac. sg. 2", "pr. [6] ac. sg. 1", "ca. impft. ac. pl. 3", "opt. [8] ac. sg. 3", "des. pr. md. du. 1", "ca. imp. ac. pl. 1", "pr. [1] md. pl. 3", "aor. [5] ac. sg. 2", "pfp. md. sg. 3", "imp. [2] md. sg. 1", "impft. [3] ac. sg. 3", "fut. ac. pl. 1", "pr. [vn.] ac. pl. 3", "ca. imp. ps. sg. 3", "impft. [10] ac. sg. 3", "imp. ac. pl. 3", "impft. [8] ac. pl. 3", "impft. [10] ac. sg. 2", "impft. [1] ac. pl. 3", "pr. [10] ac. du. 3", "impft. [10] ac. du. 3", "pp. . sg. 1", "pr. [4] md. sg. 3", "opt. [1] ac. pl. 2", "pfp. md. sg. 1", "pr. [2] ac. sg. 3", "opt. ac. pl. 1", "pp. ac. sg. 2", "pr. [5] ac. du. 3", "pr. md. sg. ", "pr. [8] md. sg. 3", "ca. imp. ac. sg. 2", "imp. ac. pl. 1", "imp. ps. sg. 2", "pft. md. pl. ", "ca. impft. md. sg. 3", "pr. [6] md. sg. 3", "opt. [1] ac. sg. 3", "pft. ac. sg. 1", "pft. ac. du. 3", "pr. [vn.] ac. sg. 3", "imp. ps. pl. 3", "opt. ps. sg. 3", "pr. ac. pl. 1", "pft. md. du. 3", "pr. [1] md. sg. 1", "imp. [1] ac. sg. 3", "pr. [3] ac. sg. 1", "pft. ps. sg. 3", "ca. pr. ac. du. 3", "impft. [9] ac. sg. 3", "impft . ac. sg. 1", "inj. [2] ac. sg. 3", "imp. [1] ac. sg. 2", "pr. [6] ac. du. 3", "fut. ac. du. 3", "imp. [vn.] ac. sg. 2", "imp. ac. sg. 2", "pr. ps. pl. 3", "pr. [6] ac. pl. 2", "ppr. . pl. ", "ppr. . sg. ", "pr. md. sg. 1", "imp. [9] ac. sg. 3", "impft. md. sg. ", "aor. [2] ac. sg. 3", "impft. 
md. sg. 2", "pr. md. sg. 3", "impft. [2] ac. sg. 1", "fut. ac. sg. 2", "opt. [2] ac. sg. 3", "imp. [8] ac. sg. 3", "imp. [10] ac. sg. 2", "pr. [4] ac. sg. 3", "pr. [8] ac. sg. 3", "fut. md. sg. 1", "opt. [8] md. sg. 3", "ca. pr. md. sg. 3", "aor. [5] ac. sg. 3", "impft. [1] md. sg. 2", "pr. [1] md. du. 3", "cond. ac. pl. 3", "imp. [4] ac. pl. 2", "impft. [1] ac. du. 1", "impft. [4] ac. pl. 3", "pr. [3] ac. pl. 3", "impft. [1] md. sg. 1", "pr. md. . 3", "imp. [1] ac. du. 3", "pr. ac. du. 3", "pr. [9] ac. sg. 2", "pr. [2] ac. du. 1", "aor. [4] ac. du. 2", "pr. [1] ac. sg. 1", "opt. [5] ac. sg. 3", "opt. [10] ac. pl. 2", "imp. [10] ac. du. 3", "aor. [1] ac. sg. 1", "pr. [10] ac. pl. 3", "opt. [2] ac. pl. 3", "opt. [1] ac. du. 3", "pr. [4] md. pl. 3", "impft. ac. sg. 2", "impft . ac. du. 3", "cond. ac. sg. 2", "pr. . sg. 3", "impft. [2] ac. sg. 3", "opt. [6] ac. sg. 3", "pr. [5] ac. sg. 3", "pft. ps. du. 3", "des. imp. md. du. 3", "impft . ac. sg. 3", "impft . ac. sg. ", "pr. md. sg. 2", "impft. [vn.] ac. sg. 3", "pr. [2] ac. sg. 2", "pr. [6] ac. sg. 3", "pr. ac. sg. 2", "pr. [7] ac. pl. 3", "imp. [3] ac. sg. 2", "opt. [2] ac. sg. 1", "pr. [1] md. pl. 2", "pft. md. pl. 3", "pr. [1] ac. sg. 3", "pr. [1] ac. pl. 1", "pr. [3] ac. du. 3", "impft. md. du. 3", "opt. [10] ac. du. 2", "imp. ac. du. 2", "opt. [1] ac. pl. 3", "impft. [vn.] md. sg. 3", "impft . ac. sg. 2", "aor. [4] ac. sg. 3", "ca. fut. ac. sg. 1", "imp. ac. sg. 3", "fut. ac. pl. 2", "impft. md. pl. 3", "ca. fut. ac. sg. 3", "opt. [2] ac. sg. 2", "impft. ac. pl. 3", "pr. [10] ac. sg. 1", "impft. md. sg. 3", "pr. [6] ac. sg. 2", "ca. pr. ps. sg. 3", "impft. ac. sg. 3", "imp. [4] ac. sg. 3", "pr. ps. sg. 1", "pr. [2] md. sg. 3", "imp. [1] ac. du. 2", "impft. [4] ac. sg. 3", "ca. opt. ac. sg. 3", "ca. pr. ac. sg. 3", "opt. ac. pl. 3", "imp. [1] md. sg. 3", "pr. [10] ac. sg. 2", "ca. pr. ps. du. 3", "opt. [5] ac. sg. 1", "impft. ps. pl. 3", "pr. [5] md. sg. 3", "pft. ps. pl. 3", "pft. ac. pl. 2", "ca. opt. ps. sg. 3", "impft. [6] ac. pl. 1", "imp. md. pl. 3", "pr. ps. du. 3", "ca. pr. ac. sg. 1", "opt. md. sg. 3", "fut. md. sg. 3", "impft . ac. pl. 3", "pr. [1] md. sg. 2", "pr. md. sg. 3", "opt. [8] ac. pl. 3", "opt. ac. du. 3", "impft. [2] ac. pl. 3", "imp. [2] ac. sg. 2", "imp. ac. pl. 2", "opt. [10] ac. sg. 3", "pft. md. sg. 1", "pr. [8] ac. sg. 1", "imp. [1] md. sg. 2", "pr. [vn.] md. sg. 3", "ca. pr. ps. sg. 1", "ca. imp. ac. sg. 3", "imp. [2] ac. du. 2", "imp. [6] ac. sg. 3", "opt. [1] ac. sg. 1", "ca. opt. ac. sg. 1", "imp. md. sg. 2", "pft. md. sg. 3", "imp. [2] ac. sg. 3", "impft. [1] ac. sg. 3", "ca. pr. ac. sg. 2", "pr. [1] md. pl. 1", "per. fut. md. sg. 3", "impft. [2] ac. sg. 2", "opt. ac. sg. 2", "imp. [3] ac. sg. 3", "imp. [3] ac. du. 2", "imp. [vn.] ac. pl. 3", "ca. pr. ac. pl. 3", "pr. [1] ac. sg. 2", "imp. [1] ac. pl. 3", "impft. [6] md. pl. 3", "per. fut. ac. sg. 3", "pr. md. pl. 3", "impft. [4] md. pl. 3", "opt. [vn.] ac. sg. 3", "aor. [1] ac. sg. 3", "impft. [4] md. sg. 3", "impft . md. sg. 3", "pr. [1] ac. pl. 3", "pft. ac. pl. 3", "opt. ac. sg. 3", "impft. ps. sg. 3", "impft. [6] ac. sg. 3", "pr. [1] ac. du. 3", "imp. [4] md. sg. 3", "per. fut. ac. sg. 3", "impft. [1] ac. sg. 1", "pr. [2] ac. du. 3", "pr. [2] ac. pl. 1", "impft. [1] ac. du. 3", "fut. ac. du. 2", "imp. [9] ac. sg. 2", "opt. [8] ac. sg. 2", "pr. [9] ac. sg. 3", "des. pr. md. pl. 3", "imp. [8] ac. sg. 2", "inj. [4] md. sg. 3", "fut. ac. pl. 3", "pr. [4] ac. pl. 3", "imp. md. sg. 1", "imp. [6] ac. sg. 2", "imp. ps. sg. 3", "imp. [vn.] ac. sg. 3", "pfp. . 
pl. ", "impft. [3] ac. pl. 3", "impft. [4] ac. sg. 1", "fut. ac. sg. 1", "per. fut. ac. sg. 2", "pr. [2] ac. pl. 3", "pr. [2] ac. sg. 1", "opt. md. . 3", "des. impft. ac. pl. 2", "opt. [3] ac. pl. 2", "pr. [5] ac. sg. 1", "pr. md. du. 3", "imp. [10] ac. sg. 3", "imp. [2] md. sg. 3", "pft. ac. sg. 3", "impft. [8] ac. sg. 3", "imp. [5] ac. sg. 2", "impft. [8] ac. du. 3", "impft . ps. sg. 3", "ca. opt. ac. sg. 2", "per. fut. ac. pl. 3", "pr. [9] ac. sg. 1", "impft . ps. du. 3", "aor. [1] ac. pl. 3", "pr. [3] ac. sg. 3", "opt. [3] ac. sg. 1", "impft. ac. sg. ", "aor. [1] ac. du. 3", "pr. [1] md. sg. 3", "impft. [1] md. pl. 1", "des. pr. md. sg. 3", "opt. [8] ac. sg. 1", "imp. [4] ac. pl. 3", "imp. [4] ac. du. 3", "pr. ac. sg. 3", "opt. ac. du. 1", "impft. [10] ac. pl. 3", "pr. ac. pl. 3", "pfp. . sg. ", "impft . ac. pl. 2", "fut. ac. sg. 3", "fut. md. pl. 3", "impft. [5] ac. sg. 3", "impft. [1] ac. sg. 2", "pr. [8] ac. pl. 3", "pr. [8] ac. du. 3", "opt. [8] ac. pl. 1", "ca. impft. ac. sg. 3", "opt. ac. sg. 1", "impft. [5] ac. sg. 1", "pp. . du. ", "pr. md. pl. 1", "impft . md. pl. 3", "pr. ps. sg. 3", "des. impft. md. pl. 1", "opt. [10] ac. sg. 2", "impft. [8] md. sg. 3", "pr. [6] md. pl. 3", "imp. [1] md. pl. 2", "opt. [1] md. sg. 3", "opt. [3] md. sg. 3", "ca. fut. md. sg. 3", "pr. [9] ac. pl. 3"], "IV": ["pp.", "inf.", "ca. ppr. ps.", "pp. . . ", "des. ppr. ac.", "ca. pfp. [1]", "ppr. ps.", "ppr.", "abs.", "ppr. [6] ac.", "ppa.", "pfp. [3]", "pfp. [1]", "ppr. [1] ac.", "impft . ac. . ", "ppr. [1] md.", "ca. inf.", "part.", "ca. pfp. [3]", "pfp.", "ppr. [8] ac.", "ca. abs.", "ca. pp.", "ca. ppa.", "ca. ppr. ac.", "ppr. [10] ac.", "ppr. [2] ac.", "pfu. ac.", "tasil", "pfp. [2]", "ca. pfp. [2]", "ppr. [6] ac. ", "pp. [2]"], "Ind": ["prep.", "ind.", "conj."], "adv": ["adv."]}
utils/io_/convert_ud_to_onto_format.py ADDED
@@ -0,0 +1,74 @@
1
+ import argparse
2
+ import os
3
+ # /home/jivnesh/Documents/DCST/d
4
+ # 'ar','eu','cs','de','hu','lv','pl','sv','fi','ru'
5
+ # python utils/io_/convert_ud_to_onto_format.py --ud_data_path d
6
+ def write_ud_files(args):
7
+ languages_for_low_resource = ['el']
8
+
9
+ languages = sorted(list(set(languages_for_low_resource)))
10
+ splits = ['train', 'dev', 'test']
11
+ lng_to_files = dict((language, {}) for language in languages)
12
+ for language, d in lng_to_files.items():
13
+ for split in splits:
14
+ d[split] = []
15
+ lng_to_files[language] = d
16
+ sub_folders = os.listdir(args.ud_data_path)
17
+ for sub_folder in sub_folders:
18
+ folder = os.path.join(args.ud_data_path, sub_folder)
19
+ files = os.listdir(folder)
20
+ for file in files:
21
+ for language in languages:
22
+ if file.startswith(language) and file.endswith('conllu'):
23
+ for split in splits:
24
+ if split in file:
25
+ full_path = os.path.join(folder, file)
26
+ lng_to_files[language][split].append(full_path)
27
+ break
28
+
29
+ for language, split_dict in lng_to_files.items():
30
+ for split, files in split_dict.items():
31
+ if split == 'dev' and len(files) == 0:
32
+ files = split_dict['train']
33
+ print('No dev files were found, copying train files instead')
34
+ sentences = []
35
+ num_sentences = 0
36
+ for file in files:
37
+ with open(file, 'r') as file:
38
+ for line in file:
39
+ new_line = []
40
+ line = line.strip()
41
+ if len(line) == 0:
42
+ sentences.append(new_line)
43
+ num_sentences += 1
44
+ continue
45
+ tokens = line.split('\t')
46
+ if not tokens[0].isdigit():
47
+ continue
48
+ id = tokens[0]
49
+ word = tokens[1]
50
+ pos = tokens[3]
51
+ ner = tokens[5]
52
+ head = tokens[6]
53
+ arc_tag = tokens[7]
54
+ new_line = [id, word, pos, ner, head, arc_tag]
55
+ sentences.append(new_line)
56
+ print('Language: %s Split: %s Num. Sentences: %s ' % (language, split, num_sentences))
57
+ if not os.path.exists('data/MRL'):
58
+ os.makedirs('data/MRL')
59
+ write_data_path = 'data/MRL/ud_pos_ner_dp_' + split + '_' + language
60
+ print('creating %s' % write_data_path)
61
+ with open(write_data_path, 'w') as f:
62
+ for line in sentences:
63
+ f.write('\t'.join(line) + '\n')
64
+
65
+ def main():
66
+ # Parse arguments
67
+ args_ = argparse.ArgumentParser()
68
+ args_.add_argument('--ud_data_path', help='Directory path of the UD treebanks.', required=True)
69
+
70
+ args = args_.parse_args()
71
+ write_ud_files(args)
72
+
73
+ if __name__ == "__main__":
74
+ main()
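The script above writes one token per line with six tab-separated fields (index, word, POS, NER/feature column, head, arc label) and a blank line between sentences. A minimal sketch of reading an emitted file back (the path and helper name below are placeholders, not part of the repository):

# Sketch (placeholder path): parse a converted file back into sentences of
# (id, word, pos, ner, head, arc_tag) tuples, using blank lines as separators.
def read_converted(path='data/MRL/ud_pos_ner_dp_train_el'):
    sentences, current = [], []
    with open(path, 'r') as f:
        for raw in f:
            raw = raw.rstrip('\n')
            if not raw:
                if current:
                    sentences.append(current)
                current = []
                continue
            current.append(tuple(raw.split('\t')))
    if current:
        sentences.append(current)
    return sentences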
utils/io_/instance.py ADDED
@@ -0,0 +1,19 @@
1
+ class Sentence(object):
2
+ def __init__(self, words, word_ids, char_seqs, char_id_seqs):
3
+ self.words = words
4
+ self.word_ids = word_ids
5
+ self.char_seqs = char_seqs
6
+ self.char_id_seqs = char_id_seqs
7
+
8
+ def length(self):
9
+ return len(self.words)
10
+
11
+ class NER_DependencyInstance(object):
12
+ def __init__(self, sentence, tokens_dict, ids_dict, heads):
13
+ self.sentence = sentence
14
+ self.tokens = tokens_dict
15
+ self.ids = ids_dict
16
+ self.heads = heads
17
+
18
+ def length(self):
19
+ return self.sentence.length()
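A toy illustration of how these two containers fit together (all ids and tags below are made up):

# Toy example (made-up ids): a two-word sentence wrapped in the classes above.
sent = Sentence(['devaḥ', 'gacchati'], [12, 47],
                [list('devaḥ'), list('gacchati')],
                [[3, 4, 5, 6, 7], [8, 6, 9, 9, 10, 6, 11, 12]])
inst = NER_DependencyInstance(sent,
                              tokens_dict={'pos_alphabet': ['NOUN', 'VERB']},
                              ids_dict={'pos_alphabet': [4, 9]},
                              heads=[2, 0])
assert inst.length() == 2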
utils/io_/logger.py ADDED
@@ -0,0 +1,15 @@
1
+ import logging
2
+ import sys
3
+
4
+
5
+ def get_logger(name, level=logging.INFO, handler=sys.stdout,
6
+ formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'):
7
+ logger = logging.getLogger(name)
8
+ logger.setLevel(level)
9
+ formatter = logging.Formatter(formatter)
10
+ stream_handler = logging.StreamHandler(handler)
11
+ stream_handler.setLevel(level)
12
+ stream_handler.setFormatter(formatter)
13
+ logger.addHandler(stream_handler)
14
+
15
+ return logger
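Typical usage is a one-liner per module (a minimal sketch):

# Sketch: a named logger that prints INFO-level messages to stdout
# using the format string configured above.
logger = get_logger('GraphParser')
logger.info('loaded %d sentences', 1234)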
utils/io_/prepare_data.py ADDED
@@ -0,0 +1,397 @@
1
+ import os.path
2
+ import numpy as np
3
+ from .alphabet import Alphabet
4
+ from .logger import get_logger
5
+ import torch
6
+
7
+ # Special vocabulary symbols - we always put them at the end.
8
+ PAD = "_<PAD>_"
9
+ ROOT = "_<ROOT>_"
10
+ END = "_<END>_"
11
+ _START_VOCAB = [PAD, ROOT, END]
12
+
13
+ MAX_CHAR_LENGTH = 45
14
+ NUM_CHAR_PAD = 2
15
+
16
+ UNK_ID = 0
17
+ PAD_ID_WORD = 1
18
+ PAD_ID_CHAR = 1
19
+ PAD_ID_TAG = 0
20
+
21
+ NUM_SYMBOLIC_TAGS = 3
22
+
23
+ _buckets = [10, 15, 20, 25, 30, 35, 40, 50, 60, 70, 80, 90, 100, 140]
24
+
25
+ from .reader import Reader
26
+
27
+ def create_alphabets(alphabet_directory, train_paths, extra_paths=None, max_vocabulary_size=100000, embedd_dict=None,
28
+ min_occurence=1, lower_case=False):
29
+ def expand_vocab(vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet):
30
+ vocab_set = set(vocab_list)
31
+ for data_path in extra_paths:
32
+ with open(data_path, 'r') as file:
33
+ for line in file:
34
+ line = line.strip()
35
+ if len(line) == 0:
36
+ continue
37
+
38
+ tokens = line.split('\t')
39
+ if lower_case:
40
+ tokens[1] = tokens[1].lower()
41
+ for char in tokens[1]:
42
+ char_alphabet.add(char)
43
+
44
+ word = tokens[1]
45
+ pos = tokens[2]
46
+ ner = tokens[3]
47
+ arc_tag = tokens[5]
48
+
49
+ pos_alphabet.add(pos)
50
+ ner_alphabet.add(ner)
51
+ arc_alphabet.add(arc_tag)
52
+ if embedd_dict is not None:
53
+ if word not in vocab_set and (word in embedd_dict or word.lower() in embedd_dict):
54
+ vocab_set.add(word)
55
+ vocab_list.append(word)
56
+ else:
57
+ if word not in vocab_set:
58
+ vocab_set.add(word)
59
+ vocab_list.append(word)
60
+ return vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet
61
+
62
+ logger = get_logger("Create Alphabets")
63
+ word_alphabet = Alphabet('word', defualt_value=True, singleton=True)
64
+ char_alphabet = Alphabet('character', defualt_value=True)
65
+ pos_alphabet = Alphabet('pos', defualt_value=True)
66
+ ner_alphabet = Alphabet('ner', defualt_value=True)
67
+ arc_alphabet = Alphabet('arc', defualt_value=True)
68
+ auto_label_alphabet = Alphabet('auto_labeler', defualt_value=True)
69
+ if not os.path.isdir(alphabet_directory):
70
+ logger.info("Creating Alphabets: %s" % alphabet_directory)
71
+
72
+ char_alphabet.add(PAD)
73
+ pos_alphabet.add(PAD)
74
+ ner_alphabet.add(PAD)
75
+ arc_alphabet.add(PAD)
76
+ auto_label_alphabet.add(PAD)
77
+
78
+ char_alphabet.add(ROOT)
79
+ pos_alphabet.add(ROOT)
80
+ ner_alphabet.add(ROOT)
81
+ arc_alphabet.add(ROOT)
82
+ auto_label_alphabet.add(ROOT)
83
+
84
+ char_alphabet.add(END)
85
+ pos_alphabet.add(END)
86
+ ner_alphabet.add(END)
87
+ arc_alphabet.add(END)
88
+ auto_label_alphabet.add(END)
89
+
90
+ vocab = dict()
91
+ if isinstance(train_paths, str):
92
+ train_paths = [train_paths]
93
+ for train_path in train_paths:
94
+ with open(train_path, 'r') as file:
95
+ for line in file:
96
+ line = line.strip()
97
+ if len(line) == 0:
98
+ continue
99
+
100
+ tokens = line.split('\t')
101
+ if lower_case:
102
+ tokens[1] = tokens[1].lower()
103
+ for char in tokens[1]:
104
+ char_alphabet.add(char)
105
+
106
+ word = tokens[1]
107
+ # print(word)
108
+ pos = tokens[2]
109
+ ner = tokens[3]
110
+ arc_tag = tokens[5]
111
+
112
+ pos_alphabet.add(pos)
113
+ ner_alphabet.add(ner)
114
+ arc_alphabet.add(arc_tag)
115
+
116
+ if word in vocab:
117
+ vocab[word] += 1
118
+ else:
119
+ vocab[word] = 1
120
+
121
+ # collect singletons
122
+ singletons = set([word for word, count in vocab.items() if count <= min_occurence])
123
+
124
+ # if a singleton also appears in the pretrained embedding dict, bump its count by min_occurence so it survives the frequency cut-off below
125
+ if embedd_dict is not None:
126
+ for word in vocab.keys():
127
+ if word in embedd_dict or word.lower() in embedd_dict:
128
+ vocab[word] += min_occurence
129
+
130
+ vocab_list = sorted(vocab, key=vocab.get, reverse=True)
131
+ vocab_list = [word for word in vocab_list if vocab[word] > min_occurence]
132
+ vocab_list = _START_VOCAB + vocab_list
133
+
134
+ if extra_paths is not None:
135
+ vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet = \
136
+ expand_vocab(vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet)
137
+
138
+ if len(vocab_list) > max_vocabulary_size:
139
+ vocab_list = vocab_list[:max_vocabulary_size]
140
+
141
+ for word in vocab_list:
142
+ word_alphabet.add(word)
143
+ if word in singletons:
144
+ word_alphabet.add_singleton(word_alphabet.get_index(word))
145
+
146
+ word_alphabet.save(alphabet_directory)
147
+ char_alphabet.save(alphabet_directory)
148
+ pos_alphabet.save(alphabet_directory)
149
+ ner_alphabet.save(alphabet_directory)
150
+ arc_alphabet.save(alphabet_directory)
151
+ auto_label_alphabet.save(alphabet_directory)
152
+
153
+ else:
154
+ print('loading saved alphabet from %s' % alphabet_directory)
155
+ word_alphabet.load(alphabet_directory)
156
+ char_alphabet.load(alphabet_directory)
157
+ pos_alphabet.load(alphabet_directory)
158
+ ner_alphabet.load(alphabet_directory)
159
+ arc_alphabet.load(alphabet_directory)
160
+ auto_label_alphabet.load(alphabet_directory)
161
+
162
+ word_alphabet.close()
163
+ char_alphabet.close()
164
+ pos_alphabet.close()
165
+ ner_alphabet.close()
166
+ arc_alphabet.close()
167
+ auto_label_alphabet.close()
168
+
169
+ alphabet_dict = {'word_alphabet': word_alphabet, 'char_alphabet': char_alphabet, 'pos_alphabet': pos_alphabet,
170
+ 'ner_alphabet': ner_alphabet, 'arc_alphabet': arc_alphabet, 'auto_label_alphabet': auto_label_alphabet}
171
+ return alphabet_dict
172
+
173
+ def create_alphabets_for_sequence_tagger(alphabet_directory, parser_alphabet_directory, paths):
174
+ logger = get_logger("Create Alphabets")
175
+ print('loading saved alphabet from %s' % parser_alphabet_directory)
176
+ word_alphabet = Alphabet('word', defualt_value=True, singleton=True)
177
+ char_alphabet = Alphabet('character', defualt_value=True)
178
+ pos_alphabet = Alphabet('pos', defualt_value=True)
179
+ ner_alphabet = Alphabet('ner', defualt_value=True)
180
+ arc_alphabet = Alphabet('arc', defualt_value=True)
181
+ auto_label_alphabet = Alphabet('auto_labeler', defualt_value=True)
182
+
183
+ word_alphabet.load(parser_alphabet_directory)
184
+ char_alphabet.load(parser_alphabet_directory)
185
+ pos_alphabet.load(parser_alphabet_directory)
186
+ ner_alphabet.load(parser_alphabet_directory)
187
+ arc_alphabet.load(parser_alphabet_directory)
188
+ try:
189
+ auto_label_alphabet.load(alphabet_directory)
190
+ except:
191
+ print('Creating auto labeler alphabet')
192
+ auto_label_alphabet.add(PAD)
193
+ auto_label_alphabet.add(ROOT)
194
+ auto_label_alphabet.add(END)
195
+ for path in paths:
196
+ with open(path, 'r') as file:
197
+ for line in file:
198
+ line = line.strip()
199
+ if len(line) == 0:
200
+ continue
201
+ tokens = line.split('\t')
202
+ if len(tokens) > 6:
203
+ auto_label = tokens[6]
204
+ auto_label_alphabet.add(auto_label)
205
+
206
+ word_alphabet.save(alphabet_directory)
207
+ char_alphabet.save(alphabet_directory)
208
+ pos_alphabet.save(alphabet_directory)
209
+ ner_alphabet.save(alphabet_directory)
210
+ arc_alphabet.save(alphabet_directory)
211
+ auto_label_alphabet.save(alphabet_directory)
212
+ word_alphabet.close()
213
+ char_alphabet.close()
214
+ pos_alphabet.close()
215
+ ner_alphabet.close()
216
+ arc_alphabet.close()
217
+ auto_label_alphabet.close()
218
+ alphabet_dict = {'word_alphabet': word_alphabet, 'char_alphabet': char_alphabet, 'pos_alphabet': pos_alphabet,
219
+ 'ner_alphabet': ner_alphabet, 'arc_alphabet': arc_alphabet, 'auto_label_alphabet': auto_label_alphabet}
220
+ return alphabet_dict
221
+
222
+ def read_data(source_path, alphabets, max_size=None,
223
+ lower_case=False, symbolic_root=False, symbolic_end=False):
224
+ data = [[] for _ in _buckets]
225
+ max_char_length = [0 for _ in _buckets]
226
+ print('Reading data from %s' % (', '.join(source_path) if type(source_path) is list else source_path))
227
+ counter = 0
228
+ if type(source_path) is not list:
229
+ source_path = [source_path]
230
+ for path in source_path:
231
+ reader = Reader(path, alphabets)
232
+ inst = reader.getNext(lower_case=lower_case, symbolic_root=symbolic_root, symbolic_end=symbolic_end)
233
+ while inst is not None and (not max_size or counter < max_size):
234
+ counter += 1
235
+ inst_size = inst.length()
236
+ sent = inst.sentence
237
+ for bucket_id, bucket_size in enumerate(_buckets):
238
+ if inst_size < bucket_size:
239
+ data[bucket_id].append([sent.word_ids, sent.char_id_seqs, inst.ids['pos_alphabet'], inst.ids['ner_alphabet'],
240
+ inst.heads, inst.ids['arc_alphabet'], inst.ids['auto_label_alphabet']])
241
+ max_len = max([len(char_seq) for char_seq in sent.char_seqs])
242
+ if max_char_length[bucket_id] < max_len:
243
+ max_char_length[bucket_id] = max_len
244
+ break
245
+
246
+ inst = reader.getNext(lower_case=lower_case, symbolic_root=symbolic_root, symbolic_end=symbolic_end)
247
+ reader.close()
248
+ print("Total number of data: %d" % counter)
249
+ return data, max_char_length
250
+
251
+ def read_data_to_variable(source_path, alphabets, device, max_size=None,
252
+ lower_case=False, symbolic_root=False, symbolic_end=False):
253
+ data, max_char_length = read_data(source_path, alphabets,
254
+ max_size=max_size, lower_case=lower_case,
255
+ symbolic_root=symbolic_root, symbolic_end=symbolic_end)
256
+ bucket_sizes = [len(data[b]) for b in range(len(_buckets))]
257
+
258
+ data_variable = []
259
+
260
+ for bucket_id in range(len(_buckets)):
261
+ bucket_size = bucket_sizes[bucket_id]
262
+ if bucket_size <= 0:
263
+ data_variable.append((1, 1))
264
+ continue
265
+
266
+ bucket_length = _buckets[bucket_id]
267
+ char_length = min(MAX_CHAR_LENGTH, max_char_length[bucket_id] + NUM_CHAR_PAD)
268
+ wid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
269
+ cid_inputs = np.empty([bucket_size, bucket_length, char_length], dtype=np.int64)
270
+ pid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
271
+ nid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
272
+ hid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
273
+ aid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
274
+ mid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
275
+
276
+ masks = np.zeros([bucket_size, bucket_length], dtype=np.float32)
277
+ single = np.zeros([bucket_size, bucket_length], dtype=np.int64)
278
+
279
+ lengths = np.empty(bucket_size, dtype=np.int64)
280
+
281
+ for i, inst in enumerate(data[bucket_id]):
282
+ wids, cid_seqs, pids, nids, hids, aids, mids = inst
283
+ inst_size = len(wids)
284
+ lengths[i] = inst_size
285
+ # word ids
286
+ wid_inputs[i, :inst_size] = wids
287
+ wid_inputs[i, inst_size:] = PAD_ID_WORD
288
+ for c, cids in enumerate(cid_seqs):
289
+ cid_inputs[i, c, :len(cids)] = cids
290
+ cid_inputs[i, c, len(cids):] = PAD_ID_CHAR
291
+ cid_inputs[i, inst_size:, :] = PAD_ID_CHAR
292
+ # pos ids
293
+ pid_inputs[i, :inst_size] = pids
294
+ pid_inputs[i, inst_size:] = PAD_ID_TAG
295
+ # ner ids
296
+ nid_inputs[i, :inst_size] = nids
297
+ nid_inputs[i, inst_size:] = PAD_ID_TAG
298
+ # arc ids
299
+ aid_inputs[i, :inst_size] = aids
300
+ aid_inputs[i, inst_size:] = PAD_ID_TAG
301
+ # auto_label ids
302
+ mid_inputs[i, :inst_size] = mids
303
+ mid_inputs[i, inst_size:] = PAD_ID_TAG
304
+ # heads
305
+ hid_inputs[i, :inst_size] = hids
306
+ hid_inputs[i, inst_size:] = PAD_ID_TAG
307
+ # masks
308
+ masks[i, :inst_size] = 1.0
309
+ for j, wid in enumerate(wids):
310
+ if alphabets['word_alphabet'].is_singleton(wid):
311
+ single[i, j] = 1
312
+
313
+ words = torch.LongTensor(wid_inputs)
314
+ chars = torch.LongTensor(cid_inputs)
315
+ pos = torch.LongTensor(pid_inputs)
316
+ ner = torch.LongTensor(nid_inputs)
317
+ heads = torch.LongTensor(hid_inputs)
318
+ arc = torch.LongTensor(aid_inputs)
319
+ auto_label = torch.LongTensor(mid_inputs)
320
+ masks = torch.FloatTensor(masks)
321
+ single = torch.LongTensor(single)
322
+ lengths = torch.LongTensor(lengths)
323
+ words = words.to(device)
324
+ chars = chars.to(device)
325
+ pos = pos.to(device)
326
+ ner = ner.to(device)
327
+ heads = heads.to(device)
328
+ arc = arc.to(device)
329
+ auto_label = auto_label.to(device)
330
+ masks = masks.to(device)
331
+ single = single.to(device)
332
+ lengths = lengths.to(device)
333
+
334
+ data_variable.append((words, chars, pos, ner, heads, arc, auto_label, masks, single, lengths))
335
+
336
+ return data_variable, bucket_sizes
337
+
338
+ def iterate_batch(data, batch_size, device, unk_replace=0.0, shuffle=False):
339
+ data_variable, bucket_sizes = data
340
+
341
+ bucket_indices = np.arange(len(_buckets))
342
+ if shuffle:
343
+ np.random.shuffle((bucket_indices))
344
+
345
+ for bucket_id in bucket_indices:
346
+ bucket_size = bucket_sizes[bucket_id]
347
+ bucket_length = _buckets[bucket_id]
348
+ if bucket_size <= 0:
349
+ continue
350
+
351
+ words, chars, pos, ner, heads, arc, auto_label, masks, single, lengths = data_variable[bucket_id]
352
+ if unk_replace:
353
+ ones = single.data.new(bucket_size, bucket_length).fill_(1)
354
+ noise = masks.data.new(bucket_size, bucket_length).bernoulli_(unk_replace).long()
355
+ words = words * (ones - single * noise)
356
+
357
+ indices = None
358
+ if shuffle:
359
+ indices = torch.randperm(bucket_size).long()
360
+ indices = indices.to(device)
361
+ for start_idx in range(0, bucket_size, batch_size):
362
+ if shuffle:
363
+ excerpt = indices[start_idx:start_idx + batch_size]
364
+ else:
365
+ excerpt = slice(start_idx, start_idx + batch_size)
366
+ yield words[excerpt], chars[excerpt], pos[excerpt], ner[excerpt], heads[excerpt], arc[excerpt], auto_label[excerpt], \
367
+ masks[excerpt], lengths[excerpt]
368
+
369
+ def iterate_batch_rand_bucket_choosing(data, batch_size, device, unk_replace=0.0):
370
+ data_variable, bucket_sizes = data
371
+ indices_left = [set(np.arange(bucket_size)) for bucket_size in bucket_sizes]
372
+ while sum(bucket_sizes) > 0:
373
+ non_empty_buckets = [i for i, bucket_size in enumerate(bucket_sizes) if bucket_size > 0]
374
+ bucket_id = np.random.choice(non_empty_buckets)
375
+ bucket_size = bucket_sizes[bucket_id]
376
+ bucket_length = _buckets[bucket_id]
377
+
378
+ words, chars, pos, ner, heads, arc, auto_label, masks, single, lengths = data_variable[bucket_id]
379
+ min_batch_size = min(bucket_size, batch_size)
380
+ indices = torch.LongTensor(np.random.choice(list(indices_left[bucket_id]), min_batch_size, replace=False))
381
+ set_indices = set(indices.numpy())
382
+ indices_left[bucket_id] = indices_left[bucket_id].difference(set_indices)
383
+ indices = indices.to(device)
384
+ words = words[indices]
385
+ if unk_replace:
386
+ ones = single.data.new(min_batch_size, bucket_length).fill_(1)
387
+ noise = masks.data.new(min_batch_size, bucket_length).bernoulli_(unk_replace).long()
388
+ words = words * (ones - single[indices] * noise)
389
+ bucket_sizes = [len(s) for s in indices_left]
390
+ yield words, chars[indices], pos[indices], ner[indices], heads[indices], arc[indices], auto_label[indices], masks[indices], lengths[indices]
391
+
392
+
393
+ def calc_num_batches(data, batch_size):
394
+ _, bucket_sizes = data
395
+ bucket_sizes_mod_batch_size = [int(bucket_size / batch_size) + 1 if bucket_size > 0 else 0 for bucket_size in bucket_sizes]
396
+ num_batches = sum(bucket_sizes_mod_batch_size)
397
+ return num_batches
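Putting the helpers above together, the intended pipeline is roughly the following (a sketch; the paths, batch size and CPU device are placeholders, and the input files are assumed to be in the six/seven-column format used elsewhere in this repository):

# Sketch of the data pipeline (placeholder paths and batch size).
import torch

device = torch.device('cpu')
alphabets = create_alphabets('saved_models/alphabets/', 'data/ud_pos_ner_dp_train_san')
train_data = read_data_to_variable('data/ud_pos_ner_dp_train_san', alphabets, device,
                                   symbolic_root=True)
print('batches per epoch:', calc_num_batches(train_data, batch_size=16))
for batch in iterate_batch(train_data, batch_size=16, device=device, shuffle=True):
    words, chars, pos, ner, heads, arcs, auto_labels, masks, lengths = batch
    # feed the batch tensors to a parser / tagger here
    break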
utils/io_/reader.py ADDED
@@ -0,0 +1,93 @@
1
+ from .instance import NER_DependencyInstance
2
+ from .instance import Sentence
3
+ from .prepare_data import ROOT, END, MAX_CHAR_LENGTH
4
+
5
+ class Reader(object):
6
+ def __init__(self, file_path, alphabets):
7
+ self.__source_file = open(file_path, 'r')
8
+ self.alphabets = alphabets
9
+
10
+ def close(self):
11
+ self.__source_file.close()
12
+
13
+ def getNext(self, lower_case=False, symbolic_root=False, symbolic_end=False):
14
+ line = self.__source_file.readline()
15
+ # skip multiple blank lines.
16
+ while len(line) > 0 and len(line.strip()) == 0:
17
+ line = self.__source_file.readline()
18
+ if len(line) == 0:
19
+ return None
20
+
21
+ lines = []
22
+ while len(line.strip()) > 0:
23
+ line = line.strip()
24
+ lines.append(line.split('\t'))
25
+ line = self.__source_file.readline()
26
+
27
+ length = len(lines)
28
+ if length == 0:
29
+ return None
30
+
31
+ heads = []
32
+ tokens_dict = {}
33
+ ids_dict = {}
34
+ for alphabet_name in self.alphabets.keys():
35
+ tokens_dict[alphabet_name] = []
36
+ ids_dict[alphabet_name] = []
37
+ if symbolic_root:
38
+ for alphabet_name, alphabet in self.alphabets.items():
39
+ if alphabet_name.startswith('char'):
40
+ tokens_dict[alphabet_name].append([ROOT, ])
41
+ ids_dict[alphabet_name].append([alphabet.get_index(ROOT), ])
42
+ else:
43
+ tokens_dict[alphabet_name].append(ROOT)
44
+ ids_dict[alphabet_name].append(alphabet.get_index(ROOT))
45
+ heads.append(0)
46
+
47
+ for tokens in lines:
48
+ chars = []
49
+ char_ids = []
50
+ if lower_case:
51
+ tokens[1] = tokens[1].lower()
52
+ for char in tokens[1]:
53
+ chars.append(char)
54
+ char_ids.append(self.alphabets['char_alphabet'].get_index(char))
55
+ if len(chars) > MAX_CHAR_LENGTH:
56
+ chars = chars[:MAX_CHAR_LENGTH]
57
+ char_ids = char_ids[:MAX_CHAR_LENGTH]
58
+ tokens_dict['char_alphabet'].append(chars)
59
+ ids_dict['char_alphabet'].append(char_ids)
60
+
61
+ word = tokens[1]
62
+ # print(word+ ' ')
63
+ pos = tokens[2]
64
+ ner = tokens[3]
65
+ head = int(tokens[4])
66
+ arc_tag = tokens[5]
67
+ if len(tokens) > 6:
68
+ auto_label = tokens[6]
69
+ tokens_dict['auto_label_alphabet'].append(auto_label)
70
+ ids_dict['auto_label_alphabet'].append(self.alphabets['auto_label_alphabet'].get_index(auto_label))
71
+ tokens_dict['word_alphabet'].append(word)
72
+ ids_dict['word_alphabet'].append(self.alphabets['word_alphabet'].get_index(word))
73
+ tokens_dict['pos_alphabet'].append(pos)
74
+ ids_dict['pos_alphabet'].append(self.alphabets['pos_alphabet'].get_index(pos))
75
+ tokens_dict['ner_alphabet'].append(ner)
76
+ ids_dict['ner_alphabet'].append(self.alphabets['ner_alphabet'].get_index(ner))
77
+ tokens_dict['arc_alphabet'].append(arc_tag)
78
+ ids_dict['arc_alphabet'].append(self.alphabets['arc_alphabet'].get_index(arc_tag))
79
+ heads.append(head)
80
+
81
+ if symbolic_end:
82
+ for alphabet_name, alphabet in self.alphabets.items():
83
+ if alphabet_name.startswith('char'):
84
+ tokens_dict[alphabet_name].append([END, ])
85
+ ids_dict[alphabet_name].append([alphabet.get_index(END), ])
86
+ else:
87
+ tokens_dict[alphabet_name].append(END)
88
+ ids_dict[alphabet_name].append(alphabet.get_index(END))
89
+ heads.append(0)
90
+
91
+ return NER_DependencyInstance(Sentence(tokens_dict['word_alphabet'], ids_dict['word_alphabet'],
92
+ tokens_dict['char_alphabet'], ids_dict['char_alphabet']),
93
+ tokens_dict, ids_dict, heads)
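The reader assumes one token per line with at least six tab-separated columns (index, word, POS, NER, head, arc label, plus an optional auto label in column 7) and blank lines between sentences. A sketch of driving it directly, assuming `alphabets` was built with create_alphabets and the path is a placeholder:

# Sketch: stream instances from a prepared data file.
reader = Reader('data/ud_pos_ner_dp_dev_san', alphabets)
inst = reader.getNext(symbolic_root=True)
while inst is not None:
    print(inst.length(), inst.heads)
    inst = reader.getNext(symbolic_root=True)
reader.close()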
utils/io_/rearrange_splits.py ADDED
@@ -0,0 +1,68 @@
1
+ import numpy as np
2
+ from torch import stack
3
+
4
+ def rearranging_splits(datasets, num_training_samples):
5
+ new_datasets = {}
6
+ data_splits = datasets.keys()
7
+ for split in data_splits:
8
+ if split == 'test':
9
+ new_datasets['test'] = datasets['test']
10
+
11
+ else:
12
+ num_buckets = len(datasets[split][1])
13
+ num_tensors = len(datasets[split][0][0])
14
+ num_samples = sum(datasets[split][1])
15
+ if num_samples < num_training_samples:
16
+ print("set_num_training_samples (%d) should be smaller than the actual %s size (%d)"
17
+ % (num_training_samples, split, num_samples))
18
+ new_datasets[split] = [[[[] for _ in range(num_tensors)] for _ in range(num_buckets)], []]
19
+ new_datasets['extra_' + split] = [[[[] for _ in range(num_tensors)] for _ in range(num_buckets)], []]
20
+ for split in data_splits:
21
+ if split == 'test':
22
+ continue
23
+ else:
24
+ curr_bucket_sizes = datasets[split][1]
25
+ curr_samples = datasets[split][0]
26
+ num_tensors = len(datasets[split][0][0])
27
+ curr_num_samples = sum(curr_bucket_sizes)
28
+ sample_indices_in_buckets = {}
29
+ i = 0
30
+ for bucket_idx, bucket_size in enumerate(curr_bucket_sizes):
31
+ for sample_idx in range(bucket_size):
32
+ sample_indices_in_buckets[i] = (bucket_idx, sample_idx)
33
+ i += 1
34
+ rng = np.arange(curr_num_samples)
35
+ rng = np.random.permutation(rng)
36
+ sample_indices = {}
37
+ sample_indices[split] = [sample_indices_in_buckets[key] for key in rng[:num_training_samples]]
38
+ sample_indices['extra_' + split] = [sample_indices_in_buckets[key] for key in rng[num_training_samples:]]
39
+ if len(sample_indices['extra_' + split]) == 0:
40
+ if len(sample_indices[split]) > 1:
41
+ sample_indices['extra_' + split].append(sample_indices[split].pop(-1))
42
+ else:
43
+ sample_indices['extra_' + split].append(sample_indices[split][0])
44
+
45
+ for key, indices in sample_indices.items():
46
+ for bucket_idx, sample_idx in indices:
47
+ curr_bucket = curr_samples[bucket_idx]
48
+ for tensor_idx, tensor in enumerate(curr_bucket):
49
+ new_datasets[key][0][bucket_idx][tensor_idx].append(tensor[sample_idx])
50
+ del datasets
51
+ new_splits = []
52
+ new_splits += [split for split in data_splits if split != 'test']
53
+ new_splits += ['extra_' + split for split in data_splits if split != 'test']
54
+
55
+ for split in new_splits:
56
+ for bucket_idx in range(num_buckets):
57
+ for tensor_idx in range(num_tensors):
58
+ if len(new_datasets[split][0][bucket_idx][tensor_idx]) > 0:
59
+ new_datasets[split][0][bucket_idx][tensor_idx] = stack(new_datasets[split][0][bucket_idx][tensor_idx])
60
+ else:
61
+ new_datasets[split][0][bucket_idx] = (1,1)
62
+ break
63
+ # set lengths of buckets
64
+ if new_datasets[split][0][bucket_idx] == (1,1):
65
+ new_datasets[split][1].append(0)
66
+ else:
67
+ new_datasets[split][1].append(len(new_datasets[split][0][bucket_idx][tensor_idx]))
68
+ return new_datasets
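A sketch of the intended call, where each value in `datasets` is the (data, bucket_sizes) pair returned by read_data_to_variable and the sample count is a placeholder:

# Sketch (placeholder count): keep 500 random samples in each non-test split
# and move the remainder into a matching 'extra_' split (used as unlabeled data).
datasets = {'train': train_data, 'dev': dev_data, 'test': test_data}
datasets = rearranging_splits(datasets, num_training_samples=500)
# the returned dict now also contains 'extra_train' and 'extra_dev'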
utils/io_/remove_xx.py ADDED
@@ -0,0 +1,60 @@
1
+ import numpy as np
2
+
3
+ def read_file(filename):
4
+ sentences = []
5
+ sentence = []
6
+ lengths = []
7
+ num_sentneces_to_remove = 0
8
+ num_sentences = 0
9
+ num_tokens_to_remove = 0
10
+ num_tokens = 0
11
+ with open(filename, 'r') as file:
12
+ for line in file:
13
+ line = line.strip()
14
+ if len(line) == 0:
15
+ xx_count = 0
16
+ for row in sentence:
17
+ if row[2] == 'XX':
18
+ xx_count += 1
19
+ if xx_count / len(sentence) >= 0.5:
20
+ num_sentneces_to_remove += 1
21
+ num_tokens_to_remove += len(sentence)
22
+ else:
23
+ sentences.append(sentence)
24
+ lengths.append(len(sentence))
25
+ num_sentences += 1
26
+ num_tokens += len(sentence)
27
+ sentence = []
28
+ continue
29
+ tokens = line.split('\t')
30
+ idx = tokens[0]
31
+ word = tokens[1]
32
+ pos = tokens[2]
33
+ ner = tokens[3]
34
+ arc = tokens[4]
35
+ arc_tag = tokens[5]
36
+ sentence.append((idx, word, pos, ner, arc, arc_tag))
37
+ print("removed %d sentences out of %d sentences" % (num_sentneces_to_remove, num_sentences))
38
+ print("removed %d tokens out of %d tokens" % (num_tokens_to_remove, num_tokens))
39
+ return sentences
40
+
41
+ def write_file(filename, sentences):
42
+ with open(filename, 'w') as file:
43
+ for sentence in sentences:
44
+ for row in sentence:
45
+ file.write('\t'.join([token for token in row]) + '\n')
46
+ file.write('\n')
47
+
48
+ dataset_dict = {'ontonotes': 'onto'}
49
+ datasets = ['ontonotes']
50
+ splits = ['test']
51
+ domains = ['all', 'wb']
52
+
53
+ for dataset in datasets:
54
+ for domain in domains:
55
+ for split in splits:
56
+ print('dataset: %s, domain: %s, split: %s' % (dataset, domain, split))
57
+ filemame = 'data/'+ dataset_dict[dataset] + '_pos_ner_dp_' + split + '_' + domain
58
+ sentences = read_file(filemame)
59
+ write_filename = filemame + '_without_xx'
60
+ write_file(write_filename, sentences)
utils/io_/seeds.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ seed = 0
5
+ torch.manual_seed(seed)
6
+ torch.cuda.manual_seed(seed)
7
+ torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
8
+ np.random.seed(seed) # Numpy module.
9
+ random.seed(seed) # Python random module.
10
+ torch.manual_seed(seed)
11
+ torch.backends.cudnn.benchmark = False
12
+ torch.backends.cudnn.deterministic = True
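Because the seeding happens at import time, importing the module once at program start is enough (a minimal sketch):

# Sketch: fix all seeds as a side effect of the import.
from utils.io_ import seeds  # noqa: F401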
utils/io_/write_extra_labels.py ADDED
@@ -0,0 +1,1592 @@
1
+ import os
2
+ from .prepare_data import ROOT, END
3
+ import pdb
4
+ def get_split(path):
5
+ if 'train' in path:
6
+ if 'extra_train' in path:
7
+ split = 'extra_train'
8
+ else:
9
+ split = 'train'
10
+ elif 'dev' in path:
11
+ if 'extra_dev' in path:
12
+ split = 'extra_dev'
13
+ else:
14
+ split = 'dev'
15
+ else:
16
+ split = 'test'
17
+ return split
18
+
19
+ def add_number_of_children(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
20
+ if src_domain == tgt_domain:
21
+ pred_paths = []
22
+ if use_unlabeled_data:
23
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and 'poetry' not in file and 'prose' not in file and 'extra' in file and tgt_domain in file]
24
+
25
+ gold_paths = [file for file in os.listdir(parser_path) if file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' not in file]
26
+ if use_labeled_data:
27
+ gold_paths += [file for file in os.listdir(parser_path) if file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' in file]
28
+
29
+ if not use_unlabeled_data and not use_labeled_data:
30
+ raise ValueError
31
+ else:
32
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
33
+
34
+ gold_paths = []
35
+ if use_labeled_data:
36
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
37
+
38
+ if not use_unlabeled_data and not use_labeled_data:
39
+ raise ValueError
40
+
41
+ paths = pred_paths + gold_paths
42
+ print("Adding labels to paths: %s" % ', '.join(paths))
43
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
44
+ writing_paths = {}
45
+ sentences = {}
46
+ for path in paths:
47
+ if tgt_domain in path:
48
+ reading_path = parser_path + path
49
+ writing_path = model_path + 'parser_' + path
50
+ split = get_split(writing_path)
51
+ else:
52
+ reading_path = path
53
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
54
+ split = 'extra_train'
55
+ writing_paths[split] = writing_path
56
+ len_sent = 0
57
+ number_of_children = {}
58
+ lines = []
59
+ sentences_list = []
60
+ with open(reading_path, 'r') as file:
61
+ for line in file:
62
+ # line = line.decode('utf-8')
63
+ line = line.strip()
64
+ # print(line)
65
+ if len(line) == 0:
66
+ for idx in range(len_sent):
67
+ node = str(idx + 1)
68
+ if node not in number_of_children:
69
+ lines[idx].append('0')
70
+ else:
71
+ lines[idx].append(str(number_of_children[node]))
72
+ if len(lines) > 0:
73
+ tmp_root_line = root_line + [str(number_of_children['0'])]
74
+ sentences_list.append(tmp_root_line)
75
+ for line_ in lines:
76
+ sentences_list.append(line_)
77
+ sentences_list.append([])
78
+ lines = []
79
+ number_of_children = {}
80
+ len_sent = 0
81
+ continue
82
+ tokens = line.split('\t')
83
+ idx = tokens[0]
84
+ word = tokens[1]
85
+ pos = tokens[2]
86
+ ner = tokens[3]
87
+ head = tokens[4]
88
+ arc_tag = tokens[5]
89
+ if head not in number_of_children:
90
+ number_of_children[head] = 1
91
+ else:
92
+ number_of_children[head] += 1
93
+ lines.append([idx, word, pos, ner, head, arc_tag])
94
+ len_sent += 1
95
+ sentences[split] = sentences_list
96
+
97
+ train_sentences = []
98
+ if 'train' in sentences:
99
+ train_sentences = sentences['train']
100
+ else:
101
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
102
+ if 'extra_train' in sentences:
103
+ train_sentences += sentences['extra_train']
104
+ del writing_paths['extra_train']
105
+ if 'extra_dev' in sentences:
106
+ train_sentences += sentences['extra_dev']
107
+ del writing_paths['extra_dev']
108
+ with open(writing_paths['train'], 'w') as f:
109
+ for sent in train_sentences:
110
+ f.write('\t'.join(sent) + '\n')
111
+ for split in ['dev', 'test']:
112
+ if split in sentences:
113
+ split_sentences = sentences[split]
114
+ with open(writing_paths[split], 'w') as f:
115
+ for sent in split_sentences:
116
+ f.write('\t'.join(sent) + '\n')
117
+ return writing_paths
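The auxiliary label written by this function is simply the number of dependents each token (and the artificial root) has in the predicted or gold tree. A self-contained toy sketch of that count, with made-up heads:

# Toy sketch of the label computed above: token i+1 attaches to heads[i]
# (0 means the artificial root); the label is each token's dependent count.
heads = ['2', '0', '2']
children = {}
for h in heads:
    children[h] = children.get(h, 0) + 1
labels = [str(children.get(str(i + 1), 0)) for i in range(len(heads))]
print(labels)  # ['0', '2', '0'] -> token 2 has two dependents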
118
+
119
+
120
+ def add_distance_from_the_root(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
121
+ if src_domain == tgt_domain:
122
+ pred_paths = []
123
+ if use_unlabeled_data:
124
+ pred_paths = [file for file in os.listdir(parser_path) if
125
+ file.endswith("pred.txt") and 'poetry' not in file and 'prose' not in file and 'extra' in file and tgt_domain in file]
126
+
127
+ gold_paths = [file for file in os.listdir(parser_path) if
128
+ file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' not in file]
129
+ if use_labeled_data:
130
+ gold_paths += [file for file in os.listdir(parser_path) if
131
+ file.endswith("gold.txt") and 'extra' not in file and 'poetry' not in file and 'prose' not in file and tgt_domain in file and 'train' in file]
132
+
133
+ if not use_unlabeled_data and not use_labeled_data:
134
+ raise ValueError
135
+ else:
136
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
137
+
138
+ gold_paths = []
139
+ if use_labeled_data:
140
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
141
+
142
+ if not use_unlabeled_data and not use_labeled_data:
143
+ raise ValueError
144
+
145
+ paths = pred_paths + gold_paths
146
+ print("Adding labels to paths: %s" % ', '.join(paths))
147
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '0']
148
+ writing_paths = {}
149
+ sentences = {}
150
+ for path in paths:
151
+ if tgt_domain in path:
152
+ reading_path = parser_path + path
153
+ writing_path = model_path + 'parser_' + path
154
+ split = get_split(writing_path)
155
+ else:
156
+ reading_path = path
157
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
158
+ split = 'extra_train'
159
+ writing_paths[split] = writing_path
160
+ len_sent = 0
161
+ tree_dict = {'0': '0'}
162
+ lines = []
163
+ sentences_list = []
164
+ with open(reading_path, 'r') as file:
165
+ for line in file:
166
+ # line = line.decode('utf-8')
167
+ line = line.strip()
168
+ if len(line) == 0:
169
+ for idx in range(len_sent):
170
+ depth = 1
171
+ node = str(idx + 1)
172
+ while tree_dict[node] != '0':
173
+ node = tree_dict[node]
174
+ depth += 1
175
+ lines[idx].append(str(depth))
176
+ if len(lines) > 0:
177
+ sentences_list.append(root_line)
178
+ for line_ in lines:
179
+ sentences_list.append(line_)
180
+ sentences_list.append([])
181
+ lines = []
182
+ tree_dict = {'0': '0'}
183
+ len_sent = 0
184
+ continue
185
+ tokens = line.split('\t')
186
+ idx = tokens[0]
187
+ word = tokens[1]
188
+ pos = tokens[2]
189
+ ner = tokens[3]
190
+ head = tokens[4]
191
+ arc_tag = tokens[5]
192
+ tree_dict[idx] = head
193
+ lines.append([idx, word, pos, ner, head, arc_tag])
194
+ len_sent += 1
195
+ sentences[split] = sentences_list
196
+
197
+ train_sentences = []
198
+ if 'train' in sentences:
199
+ train_sentences = sentences['train']
200
+ else:
201
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
202
+ if 'extra_train' in sentences:
203
+ train_sentences += sentences['extra_train']
204
+ del writing_paths['extra_train']
205
+ if 'extra_dev' in sentences:
206
+ train_sentences += sentences['extra_dev']
207
+ del writing_paths['extra_dev']
208
+ with open(writing_paths['train'], 'w') as f:
209
+ for sent in train_sentences:
210
+ f.write('\t'.join(sent) + '\n')
211
+ for split in ['dev', 'test']:
212
+ if split in sentences:
213
+ split_sentences = sentences[split]
214
+ with open(writing_paths[split], 'w') as f:
215
+ for sent in split_sentences:
216
+ f.write('\t'.join(sent) + '\n')
217
+ return writing_paths
218
+
219
+ def add_relative_pos_based(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
220
+ # most of the code for this function is taken from:
221
+ # https://github.com/mstrise/dep2label/blob/master/encoding.py
222
+ def pos_cluster(pos):
223
+ # clustering the parts of speech
224
+ if pos[0] == 'V':
225
+ pos = 'VB'
226
+ elif pos == 'NNS':
227
+ pos = 'NN'
228
+ elif pos == 'NNPS':
229
+ pos = 'NNP'
230
+ elif 'JJ' in pos:
231
+ pos = 'JJ'
232
+ elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
233
+ pos = 'RB'
234
+ elif pos[:3] == 'PRP':
235
+ pos = 'PRP'
236
+ elif pos in ['.', ':', ',', "''", '``']:
237
+ pos = '.'
238
+ elif pos[0] == '-':
239
+ pos = '-RB-'
240
+ elif pos[:2] == 'WP':
241
+ pos = 'WP'
242
+ return pos
243
+
244
+ if src_domain == tgt_domain:
245
+ pred_paths = []
246
+ if use_unlabeled_data:
247
+ pred_paths = [file for file in os.listdir(parser_path) if
248
+ file.endswith("pred.txt") and 'poetry' not in file and 'prose' not in file and 'extra' in file and tgt_domain in file]
249
+
250
+ gold_paths = [file for file in os.listdir(parser_path) if
251
+ file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' not in file]
252
+ if use_labeled_data:
253
+ gold_paths += [file for file in os.listdir(parser_path) if
254
+ file.endswith("gold.txt") and 'extra' not in file and 'poetry' not in file and 'prose' not in file and tgt_domain in file and 'train' in file]
255
+
256
+ if not use_unlabeled_data and not use_labeled_data:
257
+ raise ValueError
258
+ else:
259
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
260
+ gold_paths = []
261
+ if use_labeled_data:
262
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
263
+
264
+ if not use_unlabeled_data:
265
+ raise ValueError
266
+
267
+ paths = pred_paths + gold_paths
268
+ print("Adding labels to paths: %s" % ', '.join(paths))
269
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '+0_XX']
270
+ writing_paths = {}
271
+ sentences = {}
272
+ for path in paths:
273
+ if tgt_domain in path:
274
+ reading_path = parser_path + path
275
+ writing_path = model_path + 'parser_' + path
276
+ split = get_split(writing_path)
277
+ else:
278
+ reading_path = path
279
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
280
+ split = 'extra_train'
281
+ writing_paths[split] = writing_path
282
+ len_sent = 0
283
+ tree_dict = {'0': '0'}
284
+ lines = []
285
+ sentences_list = []
286
+ with open(reading_path, 'r') as file:
287
+ for line in file:
288
+ # line = line.decode('utf-8')
289
+ line = line.strip()
290
+ if len(line) == 0:
291
+ for idx in range(len_sent):
292
+ info_of_a_word = lines[idx]
293
+ # head is on the right side from the word
294
+ head = int(info_of_a_word[4]) - 1
295
+ if head == -1:
296
+ info_about_head = root_line
297
+ else:
298
+ info_about_head = lines[head]
299
+ if idx < head:
300
+ relative_position_head = 1
301
+ postag_head = pos_cluster(info_about_head[2])
302
+
303
+ for x in range(idx + 1, head):
304
+ another_word = lines[x]
305
+ postag_word_before_head = pos_cluster(another_word[2])
306
+ if postag_word_before_head == postag_head:
307
+ relative_position_head += 1
308
+ label = str(
309
+ "+" +
310
+ repr(relative_position_head) +
311
+ "_" +
312
+ postag_head)
313
+ lines[idx].append(label)
314
+
315
+ # head is on the left side from the word
316
+ elif idx > head:
317
+ relative_position_head = 1
318
+ postag_head = pos_cluster(info_about_head[2])
319
+ for x in range(head + 1, idx):
320
+ another_word = lines[x]
321
+ postag_word_before_head = pos_cluster(another_word[2])
322
+ if postag_word_before_head == postag_head:
323
+ relative_position_head += 1
324
+ label = str(
325
+ "-" +
326
+ repr(relative_position_head) +
327
+ "_" +
328
+ postag_head)
329
+ lines[idx].append(label)
330
+ if len(lines) > 0:
331
+ sentences_list.append(root_line)
332
+ for line_ in lines:
333
+ sentences_list.append(line_)
334
+ sentences_list.append([])
335
+ lines = []
336
+ tree_dict = {'0': '0'}
337
+ len_sent = 0
338
+ continue
339
+ tokens = line.split('\t')
340
+ idx = tokens[0]
341
+ word = tokens[1]
342
+ pos = tokens[2]
343
+ ner = tokens[3]
344
+ head = tokens[4]
345
+ arc_tag = tokens[5]
346
+ tree_dict[idx] = head
347
+ lines.append([idx, word, pos, ner, head, arc_tag])
348
+ len_sent += 1
349
+ sentences[split] = sentences_list
350
+
351
+ train_sentences = []
352
+ if 'train' in sentences:
353
+ train_sentences = sentences['train']
354
+ else:
355
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
356
+ if 'extra_train' in sentences:
357
+ train_sentences += sentences['extra_train']
358
+ del writing_paths['extra_train']
359
+ if 'extra_dev' in sentences:
360
+ train_sentences += sentences['extra_dev']
361
+ del writing_paths['extra_dev']
362
+ with open(writing_paths['train'], 'w') as f:
363
+ for sent in train_sentences:
364
+ f.write('\t'.join(sent) + '\n')
365
+ for split in ['dev', 'test']:
366
+ if split in sentences:
367
+ split_sentences = sentences[split]
368
+ with open(writing_paths[split], 'w') as f:
369
+ for sent in split_sentences:
370
+ f.write('\t'.join(sent) + '\n')
371
+ return writing_paths
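The label produced here follows the dep2label-style relative encoding: a sign for the direction of the head, the head's rank among same-POS words between the token and its head, and the head's clustered POS, so '+2_NN' reads "my head is the second NN to my right". A simplified toy sketch (it skips the POS clustering and uses 'ROOT' in place of the artificial root's 'XX' tag):

# Toy sketch of the relative positional label (simplified; no POS clustering).
def relative_pos_label(idx, sent):
    # sent[i] = (word, pos, head) with head 1-based and 0 for the root
    head = sent[idx][2] - 1
    head_pos = 'ROOT' if head == -1 else sent[head][1]
    lo, hi = (idx + 1, head) if idx < head else (head + 1, idx)
    rank = 1 + sum(1 for x in range(lo, hi) if sent[x][1] == head_pos)
    return ('+' if idx < head else '-') + str(rank) + '_' + head_pos

toy = [('the', 'DT', 3), ('old', 'JJ', 3), ('dog', 'NN', 0)]
print([relative_pos_label(i, toy) for i in range(len(toy))])
# ['+1_NN', '+1_NN', '-1_ROOT']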
372
+
373
+ def add_language_model(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
374
+ if src_domain == tgt_domain:
375
+ pred_paths = []
376
+ if use_unlabeled_data:
377
+ pred_paths = [file for file in os.listdir(parser_path) if
378
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
379
+
380
+ gold_paths = [file for file in os.listdir(parser_path) if
381
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
382
+ if use_labeled_data:
383
+ gold_paths += [file for file in os.listdir(parser_path) if
384
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
385
+
386
+ if not use_unlabeled_data and not use_labeled_data:
387
+ raise ValueError
388
+ else:
389
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
390
+
391
+ gold_paths = []
392
+ if use_labeled_data:
393
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
394
+
395
+ if not use_unlabeled_data and not use_labeled_data:
396
+ raise ValueError
397
+
398
+ paths = pred_paths + gold_paths
399
+ print("Adding labels to paths: %s" % ', '.join(paths))
400
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
401
+ writing_paths = {}
402
+ sentences = {}
403
+ for path in paths:
404
+ if tgt_domain in path:
405
+ reading_path = parser_path + path
406
+ writing_path = model_path + 'parser_' + path
407
+ split = get_split(writing_path)
408
+ else:
409
+ reading_path = path
410
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
411
+ split = 'extra_train'
412
+ writing_paths[split] = writing_path
413
+ len_sent = 0
414
+ lines = []
415
+ sentences_list = []
416
+ with open(reading_path, 'r') as file:
417
+ for line in file:
418
+ # line = line.decode('utf-8')
419
+ line = line.strip()
420
+ if len(line) == 0:
421
+ for idx in range(len_sent):
422
+ if idx < len_sent - 1:
423
+ lines[idx].append(lines[idx+1][1])
424
+ else:
425
+ lines[idx].append(END)
426
+ if len(lines) > 0:
427
+ tmp_root_line = root_line + [lines[0][1]]
428
+ sentences_list.append(tmp_root_line)
429
+ for line_ in lines:
430
+ sentences_list.append(line_)
431
+ sentences_list.append([])
432
+ lines = []
433
+ len_sent = 0
434
+ continue
435
+ tokens = line.split('\t')
436
+ idx = tokens[0]
437
+ word = tokens[1]
438
+ pos = tokens[2]
439
+ ner = tokens[3]
440
+ head = tokens[4]
441
+ arc_tag = tokens[5]
442
+ lines.append([idx, word, pos, ner, head, arc_tag])
443
+ len_sent += 1
444
+ sentences[split] = sentences_list
445
+
446
+ train_sentences = []
447
+ if 'train' in sentences:
448
+ train_sentences = sentences['train']
449
+ else:
450
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
451
+ if 'extra_train' in sentences:
452
+ train_sentences += sentences['extra_train']
453
+ del writing_paths['extra_train']
454
+ if 'extra_dev' in sentences:
455
+ train_sentences += sentences['extra_dev']
456
+ del writing_paths['extra_dev']
457
+ with open(writing_paths['train'], 'w') as f:
458
+ for sent in train_sentences:
459
+ f.write('\t'.join(sent) + '\n')
460
+ for split in ['dev', 'test']:
461
+ if split in sentences:
462
+ split_sentences = sentences[split]
463
+ with open(writing_paths[split], 'w') as f:
464
+ for sent in split_sentences:
465
+ f.write('\t'.join(sent) + '\n')
466
+ return writing_paths
467
+
468
+ def add_relative_TAG(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
469
+ # most of the code for this function is taken from:
470
+ # https://github.com/mstrise/dep2label/blob/master/encoding.py
471
+ def pos_cluster(pos):
472
+ # clustering the parts of speech
473
+ if pos[0] == 'V':
474
+ pos = 'VB'
475
+ elif pos == 'NNS':
476
+ pos = 'NN'
477
+ elif pos == 'NNPS':
478
+ pos = 'NNP'
479
+ elif 'JJ' in pos:
480
+ pos = 'JJ'
481
+ elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
482
+ pos = 'RB'
483
+ elif pos[:3] == 'PRP':
484
+ pos = 'PRP'
485
+ elif pos in ['.', ':', ',', "''", '``']:
486
+ pos = '.'
487
+ elif pos[0] == '-':
488
+ pos = '-RB-'
489
+ elif pos[:2] == 'WP':
490
+ pos = 'WP'
491
+ return pos
492
+
493
+ if src_domain == tgt_domain:
494
+ pred_paths = []
495
+ if use_unlabeled_data:
496
+ pred_paths = [file for file in os.listdir(parser_path) if
497
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
498
+
499
+ gold_paths = [file for file in os.listdir(parser_path) if
500
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
501
+ if use_labeled_data:
502
+ gold_paths += [file for file in os.listdir(parser_path) if
503
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
504
+
505
+ if not use_unlabeled_data and not use_labeled_data:
506
+ raise ValueError
507
+ else:
508
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
509
+ gold_paths = []
510
+ if use_labeled_data:
511
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
512
+
513
+ if not use_unlabeled_data:
514
+ raise ValueError
515
+
516
+ paths = pred_paths + gold_paths
517
+ print("Adding labels to paths: %s" % ', '.join(paths))
518
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '+0_XX']
519
+ writing_paths = {}
520
+ sentences = {}
521
+ for path in paths:
522
+ if tgt_domain in path:
523
+ reading_path = parser_path + path
524
+ writing_path = model_path + 'parser_' + path
525
+ split = get_split(writing_path)
526
+ else:
527
+ reading_path = path
528
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
529
+ split = 'extra_train'
530
+ writing_paths[split] = writing_path
531
+ len_sent = 0
532
+ tree_dict = {'0': '0'}
533
+ lines = []
534
+ sentences_list = []
535
+ with open(reading_path, 'r') as file:
536
+
537
+ for line in file:
538
+ # print(line)
539
+ # print(reading_path)
540
+ # line = line.decode('utf-8')
541
+ line = line.strip()
542
+ if len(line) == 0:
543
+ for idx in range(len_sent):
544
+ info_of_a_word = lines[idx]
545
+ # head is on the right side from the word
546
+ head = int(info_of_a_word[4]) - 1
547
+ if head == -1:
548
+ info_about_head = root_line
549
+ else:
550
+ # print(len(lines), head)
551
+ info_about_head = lines[head]
552
+
553
+ if idx < head:
554
+ relative_position_head = 1
555
+ tag_head = info_about_head[5]
556
+
557
+ for x in range(idx + 1, head):
558
+ another_word = lines[x]
559
+ tag_word_before_head = another_word[5]
560
+ if tag_word_before_head == tag_head:
561
+ relative_position_head += 1
562
+ label = str(
563
+ "+" +
564
+ repr(relative_position_head) +
565
+ "_" +
566
+ tag_head)
567
+ lines[idx].append(label)
568
+
569
+ # head is on the left side from the word
570
+ elif idx > head:
571
+ relative_position_head = 1
572
+ tag_head = info_about_head[5]
573
+ for x in range(head + 1, idx):
574
+ another_word = lines[x]
575
+ tag_word_before_head = another_word[5]
576
+ if tag_word_before_head == tag_head:
577
+ relative_position_head += 1
578
+ label = str(
579
+ "-" +
580
+ repr(relative_position_head) +
581
+ "_" +
582
+ tag_head)
583
+ lines[idx].append(label)
584
+ if len(lines) > 0:
585
+ sentences_list.append(root_line)
586
+ for line_ in lines:
587
+ sentences_list.append(line_)
588
+ sentences_list.append([])
589
+ lines = []
590
+ tree_dict = {'0': '0'}
591
+ len_sent = 0
592
+ continue
593
+ tokens = line.split('\t')
594
+ idx = tokens[0]
595
+ word = tokens[1]
596
+ pos = tokens[2]
597
+ ner = tokens[3]
598
+ head = tokens[4]
599
+ arc_tag = tokens[5]
600
+ tree_dict[idx] = head
601
+ lines.append([idx, word, pos, ner, head, arc_tag])
602
+ len_sent += 1
603
+ sentences[split] = sentences_list
604
+
605
+ train_sentences = []
606
+ if 'train' in sentences:
607
+ train_sentences = sentences['train']
608
+ else:
609
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
610
+ if 'extra_train' in sentences:
611
+ train_sentences += sentences['extra_train']
612
+ del writing_paths['extra_train']
613
+ if 'extra_dev' in sentences:
614
+ train_sentences += sentences['extra_dev']
615
+ del writing_paths['extra_dev']
616
+ with open(writing_paths['train'], 'w') as f:
617
+ for sent in train_sentences:
618
+ f.write('\t'.join(sent) + '\n')
619
+ for split in ['dev', 'test']:
620
+ if split in sentences:
621
+ split_sentences = sentences[split]
622
+ with open(writing_paths[split], 'w') as f:
623
+ for sent in split_sentences:
624
+ f.write('\t'.join(sent) + '\n')
625
+ return writing_paths
626
+
627
+
628
+ def add_head(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
629
+ # most of the code for this function is taken from:
630
+ # https://github.com/mstrise/dep2label/blob/master/encoding.py
631
+ def pos_cluster(pos):
632
+ # clustering the parts of speech
633
+ if pos[0] == 'V':
634
+ pos = 'VB'
635
+ elif pos == 'NNS':
636
+ pos = 'NN'
637
+ elif pos == 'NNPS':
638
+ pos = 'NNP'
639
+ elif 'JJ' in pos:
640
+ pos = 'JJ'
641
+ elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
642
+ pos = 'RB'
643
+ elif pos[:3] == 'PRP':
644
+ pos = 'PRP'
645
+ elif pos in ['.', ':', ',', "''", '``']:
646
+ pos = '.'
647
+ elif pos[0] == '-':
648
+ pos = '-RB-'
649
+ elif pos[:2] == 'WP':
650
+ pos = 'WP'
651
+ return pos
652
+
653
+ if src_domain == tgt_domain:
654
+ pred_paths = []
655
+ if use_unlabeled_data:
656
+ pred_paths = [file for file in os.listdir(parser_path) if
657
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
658
+
659
+ gold_paths = [file for file in os.listdir(parser_path) if
660
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
661
+ if use_labeled_data:
662
+ gold_paths += [file for file in os.listdir(parser_path) if
663
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
664
+
665
+ if not use_unlabeled_data and not use_labeled_data:
666
+ raise ValueError
667
+ else:
668
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
669
+ gold_paths = []
670
+ if use_labeled_data:
671
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
672
+
673
+ if not use_unlabeled_data:
674
+ raise ValueError
675
+
676
+ paths = pred_paths + gold_paths
677
+ print("Adding labels to paths: %s" % ', '.join(paths))
678
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '+0_XX']
679
+ writing_paths = {}
680
+ sentences = {}
681
+ for path in paths:
682
+ if tgt_domain in path:
683
+ reading_path = parser_path + path
684
+ writing_path = model_path + 'parser_' + path
685
+ split = get_split(writing_path)
686
+ else:
687
+ reading_path = path
688
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
689
+ split = 'extra_train'
690
+ writing_paths[split] = writing_path
691
+ len_sent = 0
692
+ tree_dict = {'0': '0'}
693
+ lines = []
694
+ sentences_list = []
695
+ with open(reading_path, 'r') as file:
696
+
697
+ for line in file:
698
+ # print(line)
699
+ # print(reading_path)
700
+ # line = line.decode('utf-8')
701
+ line = line.strip()
702
+ if len(line) == 0:
703
+ for idx in range(len_sent):
704
+ info_of_a_word = lines[idx]
705
+ # 0-based index of the head token (-1 denotes the artificial root)
706
+ head = int(info_of_a_word[4]) - 1
707
+ if head == -1:
708
+ info_about_head = root_line
709
+ else:
710
+ # print(len(lines), head)
711
+ info_about_head = lines[head]
712
+ head_word = info_about_head[1]
713
+ lines[idx].append(head_word)
714
+ # if idx < head:
715
+ # relative_position_head = 1
716
+
717
+
718
+ # for x in range(idx + 1, head):
719
+ # another_word = lines[x]
720
+ # postag_word_before_head = pos_cluster(another_word[2])
721
+ # if postag_word_before_head == postag_head:
722
+ # relative_position_head += 1
723
+ # label = str(
724
+ # "+" +
725
+ # repr(relative_position_head) +
726
+ # "_" +
727
+ # postag_head)
728
+
729
+
730
+ # # head is on the left side from the word
731
+ # elif idx > head:
732
+ # relative_position_head = 1
733
+ # postag_head = pos_cluster(info_about_head[2])
734
+ # for x in range(head + 1, idx):
735
+ # another_word = lines[x]
736
+ # postag_word_before_head = pos_cluster(another_word[2])
737
+ # if postag_word_before_head == postag_head:
738
+ # relative_position_head += 1
739
+ # label = str(
740
+ # "-" +
741
+ # repr(relative_position_head) +
742
+ # "_" +
743
+ # postag_head)
744
+ # lines[idx].append(label)
745
+ if len(lines) > 0:
746
+ sentences_list.append(root_line)
747
+ for line_ in lines:
748
+ sentences_list.append(line_)
749
+ sentences_list.append([])
750
+ lines = []
751
+ tree_dict = {'0': '0'}
752
+ len_sent = 0
753
+ continue
754
+ tokens = line.split('\t')
755
+ idx = tokens[0]
756
+ word = tokens[1]
757
+ pos = tokens[2]
758
+ ner = tokens[3]
759
+ head = tokens[4]
760
+ arc_tag = tokens[5]
761
+ tree_dict[idx] = head
762
+ lines.append([idx, word, pos, ner, head, arc_tag])
763
+ len_sent += 1
764
+ sentences[split] = sentences_list
765
+
766
+ train_sentences = []
767
+ if 'train' in sentences:
768
+ train_sentences = sentences['train']
769
+ else:
770
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
771
+ if 'extra_train' in sentences:
772
+ train_sentences += sentences['extra_train']
773
+ del writing_paths['extra_train']
774
+ if 'extra_dev' in sentences:
775
+ train_sentences += sentences['extra_dev']
776
+ del writing_paths['extra_dev']
777
+ with open(writing_paths['train'], 'w') as f:
778
+ for sent in train_sentences:
779
+ f.write('\t'.join(sent) + '\n')
780
+ for split in ['dev', 'test']:
781
+ if split in sentences:
782
+ split_sentences = sentences[split]
783
+ with open(writing_paths[split], 'w') as f:
784
+ for sent in split_sentences:
785
+ f.write('\t'.join(sent) + '\n')
786
+ return writing_paths
787
+ import json
788
+ def get_modified_coarse(ma):
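+ # Maps a normalised MA tag to its coarse category via coarse_to_ma_dict.json;
+ # falls through (returns None) when the tag is not listed under any coarse class.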
789
+ ma = ma.replace('sgpl','sg').replace('sgdu','sg')
790
+ # resolve the mapping file relative to this module rather than a hard-coded user path
+ with open(os.path.join(os.path.dirname(__file__), 'coarse_to_ma_dict.json'), 'r') as fh:
791
+ coarse_dict = json.load(fh)
792
+ for key in coarse_dict.keys():
793
+ if ma in coarse_dict[key]:
794
+ return key
795
+ def add_head_coarse_pos(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
796
+ # most of the code for this function is taken from:
797
+ # https://github.com/mstrise/dep2label/blob/master/encoding.py
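+ # Auxiliary label produced here: the PoS tag of each token's syntactic head,
+ # taken as-is from column 2 (the pos_cluster() grouping below is not applied).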
798
+ def pos_cluster(pos):
799
+ # clustering the parts of speech
800
+ if pos[0] == 'V':
801
+ pos = 'VB'
802
+ elif pos == 'NNS':
803
+ pos = 'NN'
804
+ elif pos == 'NNPS':
805
+ pos = 'NNP'
806
+ elif 'JJ' in pos:
807
+ pos = 'JJ'
808
+ elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
809
+ pos = 'RB'
810
+ elif pos[:3] == 'PRP':
811
+ pos = 'PRP'
812
+ elif pos in ['.', ':', ',', "''", '``']:
813
+ pos = '.'
814
+ elif pos[0] == '-':
815
+ pos = '-RB-'
816
+ elif pos[:2] == 'WP':
817
+ pos = 'WP'
818
+ return pos
819
+
820
+ if src_domain == tgt_domain:
821
+ pred_paths = []
822
+ if use_unlabeled_data:
823
+ pred_paths = [file for file in os.listdir(parser_path) if
824
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
825
+
826
+ gold_paths = [file for file in os.listdir(parser_path) if
827
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
828
+ if use_labeled_data:
829
+ gold_paths += [file for file in os.listdir(parser_path) if
830
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
831
+
832
+ if not use_unlabeled_data and not use_labeled_data:
833
+ raise ValueError
834
+ else:
835
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
836
+ gold_paths = []
837
+ if use_labeled_data:
838
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
839
+
840
+ if not use_unlabeled_data:
841
+ raise ValueError
842
+
843
+ paths = pred_paths + gold_paths
844
+ print("Adding labels to paths: %s" % ', '.join(paths))
845
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root', 'O']
846
+ writing_paths = {}
847
+ sentences = {}
848
+ for path in paths:
849
+ if tgt_domain in path:
850
+ reading_path = parser_path + path
851
+ writing_path = model_path + 'parser_' + path
852
+ split = get_split(writing_path)
853
+ else:
854
+ reading_path = path
855
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
856
+ split = 'extra_train'
857
+ writing_paths[split] = writing_path
858
+ len_sent = 0
859
+ tree_dict = {'0': '0'}
860
+ lines = []
861
+ sentences_list = []
862
+ with open(reading_path, 'r') as file:
863
+
864
+ for line in file:
865
+ # print(line)
866
+ # print(reading_path)
867
+ # line = line.decode('utf-8')
868
+ line = line.strip()
869
+ if len(line) == 0:
870
+ for idx in range(len_sent):
871
+ info_of_a_word = lines[idx]
872
+ # 0-based index of the head token (-1 denotes the artificial root)
873
+ head = int(info_of_a_word[4]) - 1
874
+ if head == -1:
875
+ info_about_head = root_line
876
+ else:
877
+ # print(len(lines), head)
878
+ info_about_head = lines[head]
879
+ postag_head = info_about_head[2]
880
+ lines[idx].append(postag_head)
881
+ if len(lines) > 0:
882
+ sentences_list.append(root_line)
883
+ for line_ in lines:
884
+ sentences_list.append(line_)
885
+ sentences_list.append([])
886
+ lines = []
887
+ tree_dict = {'0': '0'}
888
+ len_sent = 0
889
+ continue
890
+ tokens = line.split('\t')
891
+ idx = tokens[0]
892
+ word = tokens[1]
893
+ pos = tokens[2]
894
+ ner = tokens[3]
895
+ head = tokens[4]
896
+ arc_tag = tokens[5]
897
+ tree_dict[idx] = head
898
+ lines.append([idx, word, pos, ner, head, arc_tag])
899
+ len_sent += 1
900
+ sentences[split] = sentences_list
901
+
902
+ train_sentences = []
903
+ if 'train' in sentences:
904
+ train_sentences = sentences['train']
905
+ else:
906
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
907
+ if 'extra_train' in sentences:
908
+ train_sentences += sentences['extra_train']
909
+ del writing_paths['extra_train']
910
+ if 'extra_dev' in sentences:
911
+ train_sentences += sentences['extra_dev']
912
+ del writing_paths['extra_dev']
913
+ with open(writing_paths['train'], 'w') as f:
914
+ for sent in train_sentences:
915
+ f.write('\t'.join(sent) + '\n')
916
+ for split in ['dev', 'test']:
917
+ if split in sentences:
918
+ split_sentences = sentences[split]
919
+ with open(writing_paths[split], 'w') as f:
920
+ for sent in split_sentences:
921
+ f.write('\t'.join(sent) + '\n')
922
+ return writing_paths
923
+
924
+ def add_head_ma(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
925
+ # most of the code for this function is taken from:
926
+ # https://github.com/mstrise/dep2label/blob/master/encoding.py
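+ # Auxiliary label produced here: the clustered PoS tag (via pos_cluster) of each token's syntactic head.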
927
+ def pos_cluster(pos):
928
+ # clustering the parts of speech
929
+ if pos[0] == 'V':
930
+ pos = 'VB'
931
+ elif pos == 'NNS':
932
+ pos = 'NN'
933
+ elif pos == 'NNPS':
934
+ pos = 'NNP'
935
+ elif 'JJ' in pos:
936
+ pos = 'JJ'
937
+ elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
938
+ pos = 'RB'
939
+ elif pos[:3] == 'PRP':
940
+ pos = 'PRP'
941
+ elif pos in ['.', ':', ',', "''", '``']:
942
+ pos = '.'
943
+ elif pos[0] == '-':
944
+ pos = '-RB-'
945
+ elif pos[:2] == 'WP':
946
+ pos = 'WP'
947
+ return pos
948
+
949
+ if src_domain == tgt_domain:
950
+ pred_paths = []
951
+ if use_unlabeled_data:
952
+ pred_paths = [file for file in os.listdir(parser_path) if
953
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
954
+
955
+ gold_paths = [file for file in os.listdir(parser_path) if
956
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
957
+ if use_labeled_data:
958
+ gold_paths += [file for file in os.listdir(parser_path) if
959
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
960
+
961
+ if not use_unlabeled_data and not use_labeled_data:
962
+ raise ValueError
963
+ else:
964
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
965
+ gold_paths = []
966
+ if use_labeled_data:
967
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
968
+
969
+ if not use_unlabeled_data:
970
+ raise ValueError
971
+
972
+ paths = pred_paths + gold_paths
973
+ print("Adding labels to paths: %s" % ', '.join(paths))
974
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root', 'XX']
975
+ writing_paths = {}
976
+ sentences = {}
977
+ for path in paths:
978
+ if tgt_domain in path:
979
+ reading_path = parser_path + path
980
+ writing_path = model_path + 'parser_' + path
981
+ split = get_split(writing_path)
982
+ else:
983
+ reading_path = path
984
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
985
+ split = 'extra_train'
986
+ writing_paths[split] = writing_path
987
+ len_sent = 0
988
+ tree_dict = {'0': '0'}
989
+ lines = []
990
+ sentences_list = []
991
+ with open(reading_path, 'r') as file:
992
+
993
+ for line in file:
994
+ # print(line)
995
+ # print(reading_path)
996
+ # line = line.decode('utf-8')
997
+ line = line.strip()
998
+ if len(line) == 0:
999
+ for idx in range(len_sent):
1000
+ info_of_a_word = lines[idx]
1001
+ # 0-based index of the head token (-1 denotes the artificial root)
1002
+ head = int(info_of_a_word[4]) - 1
1003
+ if head == -1:
1004
+ info_about_head = root_line
1005
+ else:
1006
+ # print(len(lines), head)
1007
+ info_about_head = lines[head]
1008
+ postag_head = pos_cluster(info_about_head[2])
1009
+ lines[idx].append(postag_head)
1010
+ if len(lines) > 0:
1011
+ sentences_list.append(root_line)
1012
+ for line_ in lines:
1013
+ sentences_list.append(line_)
1014
+ sentences_list.append([])
1015
+ lines = []
1016
+ tree_dict = {'0': '0'}
1017
+ len_sent = 0
1018
+ continue
1019
+ tokens = line.split('\t')
1020
+ idx = tokens[0]
1021
+ word = tokens[1]
1022
+ pos = tokens[2]
1023
+ ner = tokens[3]
1024
+ head = tokens[4]
1025
+ arc_tag = tokens[5]
1026
+ tree_dict[idx] = head
1027
+ lines.append([idx, word, pos, ner, head, arc_tag])
1028
+ len_sent += 1
1029
+ sentences[split] = sentences_list
1030
+
1031
+ train_sentences = []
1032
+ if 'train' in sentences:
1033
+ train_sentences = sentences['train']
1034
+ else:
1035
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
1036
+ if 'extra_train' in sentences:
1037
+ train_sentences += sentences['extra_train']
1038
+ del writing_paths['extra_train']
1039
+ if 'extra_dev' in sentences:
1040
+ train_sentences += sentences['extra_dev']
1041
+ del writing_paths['extra_dev']
1042
+ with open(writing_paths['train'], 'w') as f:
1043
+ for sent in train_sentences:
1044
+ f.write('\t'.join(sent) + '\n')
1045
+ for split in ['dev', 'test']:
1046
+ if split in sentences:
1047
+ split_sentences = sentences[split]
1048
+ with open(writing_paths[split], 'w') as f:
1049
+ for sent in split_sentences:
1050
+ f.write('\t'.join(sent) + '\n')
1051
+ return writing_paths
1052
+
1053
+ def add_label(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1054
+ if src_domain == tgt_domain:
1055
+ pred_paths = []
1056
+ if use_unlabeled_data:
1057
+ pred_paths = [file for file in os.listdir(parser_path) if
1058
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
1059
+
1060
+ gold_paths = [file for file in os.listdir(parser_path) if
1061
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
1062
+ if use_labeled_data:
1063
+ gold_paths += [file for file in os.listdir(parser_path) if
1064
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
1065
+
1066
+ if not use_unlabeled_data and not use_labeled_data:
1067
+ raise ValueError
1068
+ else:
1069
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
1070
+
1071
+ gold_paths = []
1072
+ if use_labeled_data:
1073
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
1074
+
1075
+ if not use_unlabeled_data and not use_labeled_data:
1076
+ raise ValueError
1077
+
1078
+ paths = pred_paths + gold_paths
1079
+ print('############ Add Label Task #################')
1080
+ print("Adding labels to paths: %s" % ', '.join(paths))
1081
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
1082
+ writing_paths = {}
1083
+ sentences = {}
1084
+ for path in paths:
1085
+ if tgt_domain in path:
1086
+ reading_path = parser_path + path
1087
+ writing_path = model_path + 'parser_' + path
1088
+ split = get_split(writing_path)
1089
+ else:
1090
+ reading_path = path
1091
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
1092
+ split = 'extra_train'
1093
+ writing_paths[split] = writing_path
1094
+ len_sent = 0
1095
+ lines = []
1096
+ sentences_list = []
1097
+ with open(reading_path, 'r') as file:
1098
+ for line in file:
1099
+ # line = line.decode('utf-8')
1100
+ line = line.strip()
1101
+ # A blank line marks the end of a sentence
1102
+ if len(line) == 0:
1103
+ # Append the arc label (dependency relation) as the auxiliary-label column
1104
+ for idx in range(len_sent):
1105
+ lines[idx].append(lines[idx][5])
1106
+ # Add root line first
1107
+ if len(lines) > 0:
1108
+ tmp_root_line = root_line + [root_line[5]]
1109
+ sentences_list.append(tmp_root_line)
1110
+ for line_ in lines:
1111
+ sentences_list.append(line_)
1112
+ sentences_list.append([])
1113
+ lines = []
1114
+ len_sent = 0
1115
+ continue
1116
+ tokens = line.split('\t')
1117
+ idx = tokens[0]
1118
+ word = tokens[1]
1119
+ pos = tokens[2]
1120
+ ner = tokens[3]
1121
+ head = tokens[4]
1122
+ arc_tag = tokens[5]
1123
+ lines.append([idx, word, pos, ner, head, arc_tag])
1124
+ len_sent += 1
1125
+ sentences[split] = sentences_list
1126
+
1127
+ train_sentences = []
1128
+ if 'train' in sentences:
1129
+ train_sentences = sentences['train']
1130
+ else:
1131
+ # pdb.set_trace()
1132
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
1133
+ if 'extra_train' in sentences:
1134
+ train_sentences += sentences['extra_train']
1135
+ del writing_paths['extra_train']
1136
+ if 'extra_dev' in sentences:
1137
+ train_sentences += sentences['extra_dev']
1138
+ del writing_paths['extra_dev']
1139
+ with open(writing_paths['train'], 'w') as f:
1140
+ for sent in train_sentences:
1141
+ f.write('\t'.join(sent) + '\n')
1142
+ for split in ['dev', 'test']:
1143
+ if split in sentences:
1144
+ split_sentences = sentences[split]
1145
+ with open(writing_paths[split], 'w') as f:
1146
+ for sent in split_sentences:
1147
+ f.write('\t'.join(sent) + '\n')
1148
+ return writing_paths
1149
+
1150
+ def predict_ma_tag_of_modifier(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1151
+ if src_domain == tgt_domain:
1152
+ pred_paths = []
1153
+ if use_unlabeled_data:
1154
+ pred_paths = [file for file in os.listdir(parser_path) if
1155
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
1156
+
1157
+ gold_paths = [file for file in os.listdir(parser_path) if
1158
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
1159
+ if use_labeled_data:
1160
+ gold_paths += [file for file in os.listdir(parser_path) if
1161
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
1162
+
1163
+ if not use_unlabeled_data and not use_labeled_data:
1164
+ raise ValueError
1165
+ else:
1166
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
1167
+
1168
+ gold_paths = []
1169
+ if use_labeled_data:
1170
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
1171
+
1172
+ if not use_unlabeled_data and not use_labeled_data:
1173
+ raise ValueError
1174
+
1175
+ paths = pred_paths + gold_paths
1176
+ print('############ Add Label Task #################')
1177
+ print("Adding labels to paths: %s" % ', '.join(paths))
1178
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
1179
+ writing_paths = {}
1180
+ sentences = {}
1181
+ for path in paths:
1182
+ if tgt_domain in path:
1183
+ reading_path = parser_path + path
1184
+ writing_path = model_path + 'parser_' + path
1185
+ split = get_split(writing_path)
1186
+ else:
1187
+ reading_path = path
1188
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
1189
+ split = 'extra_train'
1190
+ writing_paths[split] = writing_path
1191
+ len_sent = 0
1192
+ lines = []
1193
+ sentences_list = []
1194
+ with open(reading_path, 'r') as file:
1195
+ for line in file:
1196
+ # line = line.decode('utf-8')
1197
+ line = line.strip()
1198
+ # A blank line marks the end of a sentence
1199
+ if len(line) == 0:
1200
+ # Append the cleaned morphological (MA) tag of the word (column 3) as the auxiliary-label column
1201
+ for idx in range(len_sent):
1202
+ lines[idx].append(clean_ma(lines[idx][3]))
1203
+ # Add root line first
1204
+ if len(lines) > 0:
1205
+ tmp_root_line = root_line + [root_line[3]]
1206
+ sentences_list.append(tmp_root_line)
1207
+ for line_ in lines:
1208
+ sentences_list.append(line_)
1209
+ sentences_list.append([])
1210
+ lines = []
1211
+ len_sent = 0
1212
+ continue
1213
+ tokens = line.split('\t')
1214
+ idx = tokens[0]
1215
+ word = tokens[1]
1216
+ pos = tokens[2]
1217
+ ner = tokens[3]
1218
+ head = tokens[4]
1219
+ arc_tag = tokens[5]
1220
+ lines.append([idx, word, pos, ner, head, arc_tag])
1221
+ len_sent += 1
1222
+ sentences[split] = sentences_list
1223
+
1224
+ train_sentences = []
1225
+ if 'train' in sentences:
1226
+ train_sentences = sentences['train']
1227
+ else:
1228
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
1229
+ if 'extra_train' in sentences:
1230
+ train_sentences += sentences['extra_train']
1231
+ del writing_paths['extra_train']
1232
+ if 'extra_dev' in sentences:
1233
+ train_sentences += sentences['extra_dev']
1234
+ del writing_paths['extra_dev']
1235
+ with open(writing_paths['train'], 'w') as f:
1236
+ for sent in train_sentences:
1237
+ f.write('\t'.join(sent) + '\n')
1238
+ for split in ['dev', 'test']:
1239
+ if split in sentences:
1240
+ split_sentences = sentences[split]
1241
+ with open(writing_paths[split], 'w') as f:
1242
+ for sent in split_sentences:
1243
+ f.write('\t'.join(sent) + '\n')
1244
+ return writing_paths
1245
+
1246
+ def predict_coarse_of_modifier(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1247
+ if src_domain == tgt_domain:
1248
+ pred_paths = []
1249
+ if use_unlabeled_data:
1250
+ pred_paths = [file for file in os.listdir(parser_path) if
1251
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
1252
+
1253
+ gold_paths = [file for file in os.listdir(parser_path) if
1254
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
1255
+ if use_labeled_data:
1256
+ gold_paths += [file for file in os.listdir(parser_path) if
1257
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
1258
+
1259
+ if not use_unlabeled_data and not use_labeled_data:
1260
+ raise ValueError
1261
+ else:
1262
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
1263
+
1264
+ gold_paths = []
1265
+ if use_labeled_data:
1266
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
1267
+
1268
+ if not use_unlabeled_data and not use_labeled_data:
1269
+ raise ValueError
1270
+
1271
+ paths = pred_paths + gold_paths
1272
+ print('############ Add Label Task #################')
1273
+ print("Adding labels to paths: %s" % ', '.join(paths))
1274
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
1275
+ writing_paths = {}
1276
+ sentences = {}
1277
+ for path in paths:
1278
+ if tgt_domain in path:
1279
+ reading_path = parser_path + path
1280
+ writing_path = model_path + 'parser_' + path
1281
+ split = get_split(writing_path)
1282
+ else:
1283
+ reading_path = path
1284
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
1285
+ split = 'extra_train'
1286
+ writing_paths[split] = writing_path
1287
+ len_sent = 0
1288
+ lines = []
1289
+ sentences_list = []
1290
+ with open(reading_path, 'r') as file:
1291
+ for line in file:
1292
+ # line = line.decode('utf-8')
1293
+ line = line.strip()
1294
+ # A blank line marks the end of a sentence
1295
+ if len(line) == 0:
1296
+ # Append the word's own column-3 tag (the coarse morphological tag) as the auxiliary-label column
1297
+ for idx in range(len_sent):
1298
+ lines[idx].append(lines[idx][3])
1299
+ # Add root line first
1300
+ if len(lines) > 0:
1301
+ tmp_root_line = root_line + [root_line[3]]
1302
+ sentences_list.append(tmp_root_line)
1303
+ for line_ in lines:
1304
+ sentences_list.append(line_)
1305
+ sentences_list.append([])
1306
+ lines = []
1307
+ len_sent = 0
1308
+ continue
1309
+ tokens = line.split('\t')
1310
+ idx = tokens[0]
1311
+ word = tokens[1]
1312
+ pos = tokens[2]
1313
+ ner = tokens[3]
1314
+ head = tokens[4]
1315
+ arc_tag = tokens[5]
1316
+ lines.append([idx, word, pos, ner, head, arc_tag])
1317
+ len_sent += 1
1318
+ sentences[split] = sentences_list
1319
+
1320
+ train_sentences = []
1321
+ if 'train' in sentences:
1322
+ train_sentences = sentences['train']
1323
+ else:
1324
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
1325
+ if 'extra_train' in sentences:
1326
+ train_sentences += sentences['extra_train']
1327
+ del writing_paths['extra_train']
1328
+ if 'extra_dev' in sentences:
1329
+ train_sentences += sentences['extra_dev']
1330
+ del writing_paths['extra_dev']
1331
+ with open(writing_paths['train'], 'w') as f:
1332
+ for sent in train_sentences:
1333
+ f.write('\t'.join(sent) + '\n')
1334
+ for split in ['dev', 'test']:
1335
+ if split in sentences:
1336
+ split_sentences = sentences[split]
1337
+ with open(writing_paths[split], 'w') as f:
1338
+ for sent in split_sentences:
1339
+ f.write('\t'.join(sent) + '\n')
1340
+ return writing_paths
1341
+ import re
1342
+ def get_case(ma):
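+ # Extracts the case from a morphological-analysis string. For example, assuming an MA
+ # string of the form 'm. sg. nom.', this returns 'nom' (the instrumental marker 'i' is
+ # normalised to 'inst'); analyses without a case, such as verbal forms or indeclinables,
+ # fall back to a coarse tag ('FV', 'IV', 'Ind' or 'adv').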
1343
+ indeclinable = ['ind','prep','interj','conj','part']
1344
+ case_list = ['nom','voc','acc','i','inst','dat','abl','g','loc']
1345
+ gender_list = ['n','f','m','*']
1346
+ person_list = ['1','2','3']
1347
+ no_list = ['du','sg','pl']
1348
+ pops = [' ac',' ps']
1349
+ ma=ma.replace('sgpl','sg').replace('sgdu','sg')
1350
+ temp = re.sub(r"([\(\[]).*?([\)\]])", r"\g<1>\g<2>", ma).replace('[] ','').strip(' []')
1351
+ temp = temp.split('.')
1352
+ if temp[-1] == '':
1353
+ temp.pop(-1)
1354
+ # Initialise the fields extracted from the MA string
1355
+ case=''
1356
+ no=''
1357
+ person=''
1358
+ gender=''
1359
+ tense=''
1360
+ coarse=''
1361
+ # Remove active/passive markers
+ for a,b in enumerate(temp):
1362
+ if b in pops:
1363
+ temp.pop(a)
1364
+ # Get gender
1365
+ for a,b in enumerate(temp):
1366
+ if b.strip() in gender_list:
1367
+ gender = b.strip()
1368
+ temp.pop(a)
1369
+ # Get case
1370
+ for a,b in enumerate(temp):
1371
+ if b.strip() in case_list:
1372
+ case = b.strip()
1373
+ temp.pop(a)
1374
+ if case != '':
1375
+ coarse ='Noun'
1376
+ # Get person
1377
+ for a,b in enumerate(temp):
1378
+ if b.strip() in person_list:
1379
+ person = b.strip()
1380
+ temp.pop(a)
1381
+ # Get no
1382
+ for a,b in enumerate(temp):
1383
+ if b.strip() in no_list:
1384
+ no = b.strip()
1385
+ temp.pop(a)
1386
+ # Get Tense
1387
+ for b in temp:
1388
+ tense=tense+ ' '+b.strip()
1389
+ tense=tense.strip()
1390
+
1391
+ # print(tense)
1392
+ if tense == 'adv':
1393
+ coarse = 'adv'
1394
+ for ind in indeclinable:
1395
+ if tense == ind:
1396
+ coarse = 'Ind'
1397
+ if tense == 'abs' or tense == 'ca abs':
1398
+ coarse = 'IV'
1399
+ if tense!='' and coarse=='':
1400
+ if person !='' or no!='':
1401
+ coarse= 'FV'
1402
+ else:
1403
+ coarse = 'IV'
1404
+ if case == 'i':
1405
+ return 'inst'
1406
+
1407
+ if case !='':
1408
+ return case
1409
+ else:
1410
+ return coarse
1411
+ def clean_ma(ma):
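+ # Normalises an MA string into a single compact tag. For example, assuming an input like
+ # 'm. sg. i.', this returns 'msginst': bracketed material and active/passive markers are
+ # dropped, 'i.' becomes 'inst.', and dots and spaces are removed.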
1412
+ ma = re.sub(r"([\(\[]).*?([\)\]])", r"\g<1>\g<2>", ma).replace('[] ','').strip(' []')
+ ma = ma.replace(' ac','').replace(' ps','').replace('sgpl','sg').replace('sgdu','sg')
1413
+ ma = ma.replace('i.','inst.').replace('.','').replace(' ','')
1414
+ return ma
1415
+ def predict_case_of_modifier(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1416
+ if src_domain == tgt_domain:
1417
+ pred_paths = []
1418
+ if use_unlabeled_data:
1419
+ pred_paths = [file for file in os.listdir(parser_path) if
1420
+ file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
1421
+
1422
+ gold_paths = [file for file in os.listdir(parser_path) if
1423
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
1424
+ if use_labeled_data:
1425
+ gold_paths += [file for file in os.listdir(parser_path) if
1426
+ file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
1427
+
1428
+ if not use_unlabeled_data and not use_labeled_data:
1429
+ raise ValueError
1430
+ else:
1431
+ pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
1432
+
1433
+ gold_paths = []
1434
+ if use_labeled_data:
1435
+ gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
1436
+
1437
+ if not use_unlabeled_data and not use_labeled_data:
1438
+ raise ValueError
1439
+
1440
+ paths = pred_paths + gold_paths
1441
+ print('############ Add Label Task #################')
1442
+ print("Adding labels to paths: %s" % ', '.join(paths))
1443
+ root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
1444
+ writing_paths = {}
1445
+ sentences = {}
1446
+ for path in paths:
1447
+ if tgt_domain in path:
1448
+ reading_path = parser_path + path
1449
+ writing_path = model_path + 'parser_' + path
1450
+ split = get_split(writing_path)
1451
+ else:
1452
+ reading_path = path
1453
+ writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
1454
+ split = 'extra_train'
1455
+ writing_paths[split] = writing_path
1456
+ len_sent = 0
1457
+ lines = []
1458
+ sentences_list = []
1459
+ with open(reading_path, 'r') as file:
1460
+ for line in file:
1461
+ # line = line.decode('utf-8')
1462
+ line = line.strip()
1463
+ # A blank line marks the end of a sentence
1464
+ if len(line) == 0:
1465
+ # Append the case extracted from the word's MA tag as the auxiliary-label column
1466
+ for idx in range(len_sent):
1467
+ lines[idx].append(get_case(lines[idx][3]))
1468
+ # Add root line first
1469
+ if len(lines) > 0:
1470
+ tmp_root_line = root_line + [root_line[3]]
1471
+ sentences_list.append(tmp_root_line)
1472
+ for line_ in lines:
1473
+ sentences_list.append(line_)
1474
+ sentences_list.append([])
1475
+ lines = []
1476
+ len_sent = 0
1477
+ continue
1478
+ tokens = line.split('\t')
1479
+ idx = tokens[0]
1480
+ word = tokens[1]
1481
+ pos = tokens[2]
1482
+ ner = tokens[3]
1483
+ head = tokens[4]
1484
+ arc_tag = tokens[5]
1485
+ lines.append([idx, word, pos, ner, head, arc_tag])
1486
+ len_sent += 1
1487
+ sentences[split] = sentences_list
1488
+
1489
+ train_sentences = []
1490
+ if 'train' in sentences:
1491
+ train_sentences = sentences['train']
1492
+ else:
1493
+ writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
1494
+ if 'extra_train' in sentences:
1495
+ train_sentences += sentences['extra_train']
1496
+ del writing_paths['extra_train']
1497
+ if 'extra_dev' in sentences:
1498
+ train_sentences += sentences['extra_dev']
1499
+ del writing_paths['extra_dev']
1500
+ with open(writing_paths['train'], 'w') as f:
1501
+ for sent in train_sentences:
1502
+ f.write('\t'.join(sent) + '\n')
1503
+ for split in ['dev', 'test']:
1504
+ if split in sentences:
1505
+ split_sentences = sentences[split]
1506
+ with open(writing_paths[split], 'w') as f:
1507
+ for sent in split_sentences:
1508
+ f.write('\t'.join(sent) + '\n')
1509
+ return writing_paths
1510
+
1511
+ def Multitask_case_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1512
+ writing_paths = {}
1513
+ # multitask_silver_20ktrain_san
1514
+ writing_paths['train'] = 'data/ud_pos_ner_dp_train_san_case'
1515
+ writing_paths['dev'] = 'data/ud_pos_ner_dp_dev_san_case'
1516
+ writing_paths['test'] = 'data/ud_pos_ner_dp_test_san_case'
1517
+ return writing_paths
1518
+
1519
+ def Multitask_POS_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1520
+ writing_paths = {}
1521
+ # multitask_silver_20ktrain_san
1522
+ # writing_paths['train'] = 'data/Multitask_POS_predict_train_san'
1523
+ # writing_paths['dev'] = 'data/Multitask_POS_predict_dev_san'
1524
+ # writing_paths['test'] = 'data/Multitask_POS_predict_test_san'
1525
+ writing_paths['train'] = 'data/ud_pos_ner_dp_train_san_POS'
1526
+ writing_paths['dev'] = 'data/ud_pos_ner_dp_dev_san_POS'
1527
+ writing_paths['test'] = 'data/ud_pos_ner_dp_test_san_POS'
1528
+ return writing_paths
1529
+
1530
+ def Multitask_coarse_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1531
+ writing_paths = {}
1532
+ # multitask_silver_20ktrain_san
1533
+ writing_paths['train'] = 'data/Multitask_coarse_predict_train_san'
1534
+ writing_paths['dev'] = 'data/Multitask_coarse_predict_dev_san'
1535
+ writing_paths['test'] = 'data/Multitask_coarse_predict_test_san'
1536
+ return writing_paths
1537
+
1538
+ def Multitask_label_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1539
+ writing_paths = {}
1540
+ # multitask_silver_20ktrain_san
1541
+ writing_paths['train'] = 'data/Multitask_label_predict_train_san'
1542
+ writing_paths['dev'] = 'data/Multitask_label_predict_dev_san'
1543
+ writing_paths['test'] = 'data/Multitask_label_predict_test_san'
1544
+ return writing_paths
1545
+
1546
+
1547
+ def MRL_case(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1548
+ writing_paths = {}
1549
+ # multitask_silver_20ktrain_san
1550
+ writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_'+src_domain+'_case'
1551
+ writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_'+src_domain+'_case'
1552
+ writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_'+src_domain+'_case'
1553
+ return writing_paths
1554
+
1555
+ def MRL_POS(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1556
+ writing_paths = {}
1557
+ # multitask_silver_20ktrain_san
1558
+ writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_'+src_domain+'_POS'
1559
+ writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_'+src_domain+'_POS'
1560
+ writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_'+src_domain+'_POS'
1561
+ return writing_paths
1562
+
1563
+ def MRL_label(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1564
+ writing_paths = {}
1565
+ # multitask_silver_20ktrain_san
1566
+ writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_'+src_domain+'_dep'
1567
+ writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_'+src_domain+'_dep'
1568
+ writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_'+src_domain+'_dep'
1569
+ return writing_paths
1570
+
1571
+ def MRL_no(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1572
+ writing_paths = {}
1573
+ # multitask_silver_20ktrain_san
1574
+ writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_'+src_domain+'_no'
1575
+ writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_'+src_domain+'_no'
1576
+ writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_'+src_domain+'_no'
1577
+ return writing_paths
1578
+
1579
+ def MRL_Person(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1580
+ writing_paths = {}
1581
+ # multitask_silver_20ktrain_san
1582
+ writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_'+src_domain+'_per'
1583
+ writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_'+src_domain+'_per'
1584
+ writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_'+src_domain+'_per'
1585
+ return writing_paths
1586
+ def MRL_Gender(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
1587
+ writing_paths = {}
1588
+ # multitask_silver_20ktrain_san
1589
+ writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_'+src_domain+'_gen'
1590
+ writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_'+src_domain+'_gen'
1591
+ writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_'+src_domain+'_gen'
1592
+ return writing_paths
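+ # Note: unlike the functions above, the Multitask_* and MRL_* helpers do not derive labels
+ # from parser output; they only return paths to pre-generated auxiliary-task files (case, POS,
+ # dependency label, number, person, gender). The MRL_* paths are parameterised by src_domain,
+ # while the Multitask_* paths point at the Sanskrit (san) splits under data/.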
utils/io_/writer.py ADDED
@@ -0,0 +1,46 @@
1
+
2
+ class Writer(object):
3
+ def __init__(self, alphabets):
4
+ self.__source_file = None
5
+ self.alphabets = alphabets
6
+
7
+ def start(self, file_path):
8
+ self.__source_file = open(file_path, 'w')
9
+
10
+ def close(self):
11
+ self.__source_file.close()
12
+
13
+ def write(self, word, pos, ner, head, arc, lengths, auto_label=None, symbolic_root=False, symbolic_end=False):
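+ # Writes one token per line: index, word, POS, NER, head index, arc label and, when
+ # auto_label is given, the auxiliary label, all tab-separated; a blank line ends each sentence.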
14
+ batch_size, _ = word.shape
15
+ start = 1 if symbolic_root else 0
16
+ end = 1 if symbolic_end else 0
17
+ for i in range(batch_size):
18
+ for j in range(start, lengths[i] - end):
19
+ w = self.alphabets['word_alphabet'].get_instance(word[i, j])
20
+ p = self.alphabets['pos_alphabet'].get_instance(pos[i, j])
21
+ n = self.alphabets['ner_alphabet'].get_instance(ner[i, j])
22
+ t = self.alphabets['arc_alphabet'].get_instance(arc[i, j])
23
+ h = head[i, j]
24
+ if auto_label is not None:
25
+ m = self.alphabets['auto_label_alphabet'].get_instance(auto_label[i, j])
26
+ self.__source_file.write('%d\t%s\t%s\t%s\t%d\t%s\t%s\n' % (j, w, p, n, h, t, m))
27
+ else:
28
+ self.__source_file.write('%d\t%s\t%s\t%s\t%d\t%s\n' % (j, w, p, n, h, t))
29
+ self.__source_file.write('\n')
30
+
31
+ class Index2Instance(object):
32
+ def __init__(self, alphabet):
33
+ self.__alphabet = alphabet
34
+
35
+ def index2instance(self, indices, lengths, symbolic_root=False, symbolic_end=False):
36
+ batch_size, _ = indices.shape
37
+ start = 1 if symbolic_root else 0
38
+ end = 1 if symbolic_end else 0
39
+ instances = []
40
+ for i in range(batch_size):
41
+ tmp_instances = []
42
+ for j in range(start, lengths[i] - end):
43
+ instance = self.__alphabet.get_instance(indices[i, j])
44
+ tmp_instances.append(instance)
45
+ instances.append(tmp_instances)
46
+ return instances
utils/load_word_embeddings.py ADDED
@@ -0,0 +1,123 @@
1
+ import pickle
2
+ import numpy as np
3
+ from gensim.models import KeyedVectors
4
+ import gzip
5
+ import io
6
+ import os
7
+
8
+ def calc_mean_vec_for_lower_mapping(embedd_dict):
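+ # For example, with {'The': v1, 'the': v2} the lowercased key 'the' ends up holding the
+ # element-wise mean of v1 and v2; the original cased keys keep their vectors unchanged.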
9
+ lower_counts = {}
10
+ for word in embedd_dict:
11
+ word_lower = word.lower()
12
+ if word_lower not in lower_counts:
13
+ lower_counts[word_lower] = [word]
14
+ else:
15
+ lower_counts[word_lower] = lower_counts[word_lower] + [word]
16
+ # calculating mean vector for all words that have the same mapping after performing lower()
17
+ for word in lower_counts:
18
+ embedd_dict[word] = np.mean([embedd_dict[word_] for word_ in lower_counts[word]], axis=0)  # axis=0 keeps the word-vector shape instead of collapsing to a scalar
19
+ return embedd_dict
20
+
21
+ def load_embedding_dict(embedding, embedding_path, lower_case=False):
22
+ """
23
+ load word embeddings from file
24
+ :param embedding:
25
+ :param embedding_path:
26
+ :return: embedding dict, embedding dimension
27
+ """
28
+ print("loading embedding: %s from %s" % (embedding, embedding_path))
29
+ if lower_case:
30
+ pkl_path = embedding_path + '_lower' + '.pkl'
31
+ else:
32
+ pkl_path = embedding_path + '.pkl'
33
+ if os.path.isfile(pkl_path):
34
+ # load dict and dim from a pickle file
35
+ with open(pkl_path, 'rb') as f:
36
+ embedd_dict, embedd_dim = pickle.load(f)
37
+ print("num dimensions of word embeddings:", embedd_dim)
38
+ return embedd_dict, embedd_dim
39
+
40
+ if embedding == 'glove':
41
+ # loading GloVe
42
+ embedd_dict = {}
43
+ word = None
44
+ with io.open(embedding_path, 'r', encoding='utf-8') as f:
45
+ for line in f:
46
+ word, vec = line.split(' ', 1)
47
+ embedd_dict[word] = np.fromstring(vec, sep=' ')
48
+ embedd_dim = len(embedd_dict[word])
49
+ if lower_case:
50
+ embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
51
+ for k, v in embedd_dict.items():
52
+ if len(v) != embedd_dim:
53
+ print(len(v),embedd_dim)
54
+
55
+ elif embedding == 'fasttext':
56
+ # loading fastText
57
+ embedd_dict = {}
58
+ word = None
59
+ with io.open(embedding_path, 'r', encoding='utf-8') as f:
60
+ # skip first line
61
+ for i, line in enumerate(f):
62
+ if i == 0:
63
+ continue
64
+ word, vec = line.split(' ', 1)
65
+ embedd_dict[word] = np.fromstring(vec, sep=' ')
66
+ embedd_dim = len(embedd_dict[word])
67
+ if lower_case:
68
+ embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
69
+ for k, v in embedd_dict.items():
70
+ if len(v) != embedd_dim:
71
+ print(len(v),embedd_dim)
72
+
73
+ elif embedding == 'hellwig':
74
+ # loading hellwig
75
+ embedd_dict = {}
76
+ word = None
77
+ with io.open(embedding_path, 'r', encoding='utf-8') as f:
78
+ # skip first line
79
+ for i, line in enumerate(f):
80
+ if i == 0:
81
+ continue
82
+ word, vec = line.split(' ', 1)
83
+ embedd_dict[word] = np.fromstring(vec, sep=' ')
84
+ embedd_dim = len(embedd_dict[word])
85
+ if lower_case:
86
+ embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
87
+ for k, v in embedd_dict.items():
88
+ if len(v) != embedd_dim:
89
+ print(len(v),embedd_dim)
90
+
91
+ elif embedding == 'one_hot':
92
+ # loading one-hot embeddings
93
+ embedd_dict = {}
94
+ word = None
95
+ with io.open(embedding_path, 'r', encoding='utf-8') as f:
96
+ # skip first line
97
+ for i, line in enumerate(f):
98
+ if i == 0:
99
+ continue
100
+ word, vec = line.split('@', 1)
101
+ embedd_dict[word] = np.fromstring(vec, sep=' ')
102
+ embedd_dim = len(embedd_dict[word])
103
+ if lower_case:
104
+ embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
105
+ for k, v in embedd_dict.items():
106
+ if len(v) != embedd_dim:
107
+ print(len(v),embedd_dim)
108
+
109
+ elif embedding == 'word2vec':
110
+ # loading word2vec
111
+ embedd_dict = KeyedVectors.load_word2vec_format(embedding_path, binary=True)
112
+ if lower_case:
113
+ embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
114
+ embedd_dim = embedd_dict.vector_size
115
+
116
+ else:
117
+ raise ValueError("embedding should choose from [fasttext, glove, word2vec]")
118
+
119
+ print("num dimensions of word embeddings:", embedd_dim)
120
+ # save dict and dim to a pickle file
121
+ with open(pkl_path, 'wb') as f:
122
+ pickle.dump([embedd_dict, embedd_dim], f, pickle.HIGHEST_PROTOCOL)
123
+ return embedd_dict, embedd_dim
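+ # Hypothetical usage (the vector file name below is only an illustration, not taken from this repo):
+ # embedd_dict, dim = load_embedding_dict('fasttext', 'multilingual_word_embeddings/cc.300.vec')
+ # On first load, the parsed dict and its dimension are cached next to the vectors as a .pkl file.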
utils/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .parsing import *
2
+ from .parsing_gating import *
3
+ from .sequence_tagger import *
utils/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (190 Bytes).
 
utils/models/__pycache__/parsing.cpython-37.pyc ADDED
Binary file (10.5 kB).
 
utils/models/__pycache__/parsing_gating.cpython-37.pyc ADDED
Binary file (13.7 kB).