Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitignore +2 -0
- LICENSE.md +201 -0
- README.md +52 -3
- data/ud_pos_ner_dp_dev_san +0 -0
- data/ud_pos_ner_dp_dev_san_POS +0 -0
- data/ud_pos_ner_dp_dev_san_case +0 -0
- data/ud_pos_ner_dp_test_san +0 -0
- data/ud_pos_ner_dp_test_san_POS +0 -0
- data/ud_pos_ner_dp_test_san_case +0 -0
- data/ud_pos_ner_dp_train_san +0 -0
- data/ud_pos_ner_dp_train_san_POS +0 -0
- data/ud_pos_ner_dp_train_san_case +0 -0
- examples/GraphParser.py +703 -0
- examples/GraphParser_MRL.py +603 -0
- examples/SequenceTagger.py +597 -0
- examples/eval/conll03eval.v2 +336 -0
- examples/eval/conll06eval.pl +1826 -0
- examples/test_original_dcst.sh +110 -0
- run_san_LCM.sh +73 -0
- utils/__init__.py +7 -0
- utils/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/__pycache__/load_word_embeddings.cpython-37.pyc +0 -0
- utils/io_/__init__.py +5 -0
- utils/io_/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/io_/__pycache__/alphabet.cpython-37.pyc +0 -0
- utils/io_/__pycache__/instance.cpython-37.pyc +0 -0
- utils/io_/__pycache__/logger.cpython-37.pyc +0 -0
- utils/io_/__pycache__/prepare_data.cpython-37.pyc +0 -0
- utils/io_/__pycache__/reader.cpython-37.pyc +0 -0
- utils/io_/__pycache__/rearrange_splits.cpython-37.pyc +0 -0
- utils/io_/__pycache__/seeds.cpython-37.pyc +0 -0
- utils/io_/__pycache__/write_extra_labels.cpython-37.pyc +0 -0
- utils/io_/__pycache__/writer.cpython-37.pyc +0 -0
- utils/io_/alphabet.py +147 -0
- utils/io_/coarse_to_ma_dict.json +1 -0
- utils/io_/convert_ud_to_onto_format.py +74 -0
- utils/io_/instance.py +19 -0
- utils/io_/logger.py +15 -0
- utils/io_/prepare_data.py +397 -0
- utils/io_/reader.py +93 -0
- utils/io_/rearrange_splits.py +68 -0
- utils/io_/remove_xx.py +60 -0
- utils/io_/seeds.py +12 -0
- utils/io_/write_extra_labels.py +1592 -0
- utils/io_/writer.py +46 -0
- utils/load_word_embeddings.py +123 -0
- utils/models/__init__.py +3 -0
- utils/models/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/models/__pycache__/parsing.cpython-37.pyc +0 -0
- utils/models/__pycache__/parsing_gating.cpython-37.pyc +0 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+saved_models
+multilingual_word_embeddings
LICENSE.md
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2021 Jivnesh Sandhan
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md
CHANGED
@@ -1,3 +1,52 @@
-
-
-
+Official code for the paper ["A Little Pretraining Goes a Long Way: A Case Study on Dependency Parsing Task for Low-resource Morphologically Rich Languages"](https://arxiv.org/abs/2102.06551).
+If you use this code, please cite our paper.
+
+## Requirements
+
+* Python 3.7
+* Pytorch 1.1.0
+* Cuda 9.0
+* Gensim 3.8.1
+
+We assume that you have installed conda beforehand.
+
+```
+conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
+pip install gensim==3.8.1
+```
+## Data
+* Pretrained FastText embeddings for Sanskrit can be obtained from [here](https://drive.google.com/drive/folders/1JJMBjUZdqUY7WLYefBbA2zKaMHH3Mm18?usp=sharing). Make sure the `.vec` file is placed at the appropriate location (see the loading sketch below).
+* For multilingual experiments, we use [UD treebanks](https://universaldependencies.org/) and [pretrained FastText embeddings](https://fasttext.cc/docs/en/crawl-vectors.html).
+
+
+## How to train the model
+To run the complete model pipeline, (1) pretraining followed by (2) integration, simply run the bash script `run_san_LCM.sh`:
+
+```bash
+bash run_san_LCM.sh
+```
+
+
+## Citation
+
+If you use our tool, we'd appreciate it if you cite the following paper:
+
+```
+@inproceedings{sandhan-etal-2021-little,
+    title = "A Little Pretraining Goes a Long Way: A Case Study on Dependency Parsing Task for Low-resource Morphologically Rich Languages",
+    author = "Sandhan, Jivnesh and Krishna, Amrith and Gupta, Ashim and Behera, Laxmidhar and Goyal, Pawan",
+    booktitle = "Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Student Research Workshop",
+    month = apr,
+    year = "2021",
+    address = "Online",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2021.eacl-srw.16",
+    doi = "10.18653/v1/2021.eacl-srw.16",
+    pages = "111--120",
+    abstract = "Neural dependency parsing has achieved remarkable performance for many domains and languages. The bottleneck of massive labelled data limits the effectiveness of these approaches for low resource languages. In this work, we focus on dependency parsing for morphological rich languages (MRLs) in a low-resource setting. Although morphological information is essential for the dependency parsing task, the morphological disambiguation and lack of powerful analyzers pose challenges to get this information for MRLs. To address these challenges, we propose simple auxiliary tasks for pretraining. We perform experiments on 10 MRLs in low-resource settings to measure the efficacy of our proposed pretraining method and observe an average absolute gain of 2 points (UAS) and 3.6 points (LAS).",
+}
+```
+
+## Acknowledgements
+Much of the base code is from the ["DCST implementation"](https://github.com/rotmanguy/DCST).
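As referenced in the Data section of the README above, the pretrained Sanskrit FastText vectors ship as a `.vec` file in word2vec text format. Below is a minimal sketch for sanity-checking such a file with the gensim 3.8 API; the file name and directory are assumptions (the `.gitignore` in this commit lists a `multilingual_word_embeddings` directory), so point it at whatever path you pass to the parser's `--word_path` flag.

```python
# Minimal sketch, gensim 3.8 API: load and inspect a FastText .vec file.
# The path below is hypothetical; use the location you pass as --word_path.
from gensim.models import KeyedVectors

vectors = KeyedVectors.load_word2vec_format(
    'multilingual_word_embeddings/san.vec', binary=False)
print('dimension:', vectors.vector_size)   # should match --word_dim (default 300)
print('vocabulary:', len(vectors.vocab))   # .vocab is the gensim 3.x attribute
```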
data/ud_pos_ner_dp_dev_san
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_dev_san_POS
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_dev_san_case
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_test_san
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_test_san_POS
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_test_san_case
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_train_san
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_train_san_POS
ADDED
The diff for this file is too large to render.

data/ud_pos_ner_dp_train_san_case
ADDED
The diff for this file is too large to render.
examples/GraphParser.py
ADDED
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import sys
|
3 |
+
from os import path, makedirs
|
4 |
+
|
5 |
+
sys.path.append(".")
|
6 |
+
sys.path.append("..")
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
from copy import deepcopy
|
10 |
+
import json
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from collections import namedtuple
|
15 |
+
from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
|
16 |
+
from utils.models.parsing_gating import BiAffine_Parser_Gated
|
17 |
+
from utils import load_word_embeddings
|
18 |
+
from utils.tasks import parse
|
19 |
+
import time
|
20 |
+
from torch.nn.utils import clip_grad_norm_
|
21 |
+
from torch.optim import Adam, SGD
|
22 |
+
import uuid
|
23 |
+
import pdb
|
24 |
+
uid = uuid.uuid4().hex[:6]
|
25 |
+
|
26 |
+
logger = get_logger('GraphParser')
|
27 |
+
|
28 |
+
def read_arguments():
|
29 |
+
args_ = argparse.ArgumentParser(description='Sovling GraphParser')
|
30 |
+
args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
|
31 |
+
args_.add_argument('--domain', help='domain/language', required=True)
|
32 |
+
args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
|
33 |
+
required=True)
|
34 |
+
args_.add_argument('--gating',action='store_true', help='use gated mechanism')
|
35 |
+
args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
|
36 |
+
args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
|
37 |
+
args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
|
38 |
+
args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
|
39 |
+
args_.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
|
40 |
+
args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
|
41 |
+
args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
|
42 |
+
args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
|
43 |
+
args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
|
44 |
+
args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
|
45 |
+
args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
|
46 |
+
args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
|
47 |
+
args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
|
48 |
+
args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
|
49 |
+
args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
|
50 |
+
args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
|
51 |
+
args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
|
52 |
+
args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
|
53 |
+
args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
|
54 |
+
args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
|
55 |
+
args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
|
56 |
+
args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
|
57 |
+
args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
|
58 |
+
args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
|
59 |
+
args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
|
60 |
+
args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
|
61 |
+
args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
|
62 |
+
args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
|
63 |
+
args_.add_argument('--unk_replace', type=float, default=0.,
|
64 |
+
help='The rate to replace a singleton word with UNK')
|
65 |
+
args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
|
66 |
+
args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
|
67 |
+
help='Embedding for words')
|
68 |
+
args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
|
69 |
+
args_.add_argument('--freeze_word_embeddings', action='store_true', help='frozen the word embedding (disable fine-tuning).')
|
70 |
+
args_.add_argument('--freeze_sequence_taggers', action='store_true', help='frozen the BiLSTMs of the pre-trained taggers.')
|
71 |
+
args_.add_argument('--char_embedding', choices=['random','hellwig'], help='Embedding for characters',
|
72 |
+
required=True)
|
73 |
+
args_.add_argument('--pos_embedding', choices=['random','one_hot'], help='Embedding for pos',
|
74 |
+
required=True)
|
75 |
+
args_.add_argument('--char_path', help='path for character embedding dict')
|
76 |
+
args_.add_argument('--pos_path', help='path for pos embedding dict')
|
77 |
+
args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
|
78 |
+
args_.add_argument('--model_path', help='path for saving model file.', required=True)
|
79 |
+
args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
|
80 |
+
args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
|
81 |
+
args_.add_argument('--strict',action='store_true', help='if True loaded model state should contin '
|
82 |
+
'exactly the same keys as current model')
|
83 |
+
args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
|
84 |
+
args_.add_argument('--eval_with_CI', action='store_true', help='evaluating model in constrained inference mode')
|
85 |
+
args_.add_argument('--LCM_Path_flag', action='store_true', help='for constrained inference with LCM, flag is used to change path')
|
86 |
+
args = args_.parse_args()
|
87 |
+
args_dict = {}
|
88 |
+
args_dict['dataset'] = args.dataset
|
89 |
+
args_dict['domain'] = args.domain
|
90 |
+
args_dict['rnn_mode'] = args.rnn_mode
|
91 |
+
args_dict['gating'] = args.gating
|
92 |
+
args_dict['num_gates'] = args.num_gates
|
93 |
+
args_dict['arc_decode'] = args.arc_decode
|
94 |
+
# args_dict['splits'] = ['train', 'dev', 'test']
|
95 |
+
args_dict['splits'] = ['train', 'dev', 'test']
|
96 |
+
args_dict['model_path'] = args.model_path
|
97 |
+
if not path.exists(args_dict['model_path']):
|
98 |
+
makedirs(args_dict['model_path'])
|
99 |
+
args_dict['data_paths'] = {}
|
100 |
+
if args_dict['dataset'] == 'ontonotes':
|
101 |
+
data_path = 'data/onto_pos_ner_dp'
|
102 |
+
else:
|
103 |
+
data_path = 'data/ud_pos_ner_dp'
|
104 |
+
for split in args_dict['splits']:
|
105 |
+
args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
|
106 |
+
###################################
|
107 |
+
|
108 |
+
###################################
|
109 |
+
args_dict['alphabet_data_paths'] = {}
|
110 |
+
for split in args_dict['splits']:
|
111 |
+
if args_dict['dataset'] == 'ontonotes':
|
112 |
+
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
|
113 |
+
else:
|
114 |
+
if '_' in args_dict['domain']:
|
115 |
+
args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
|
116 |
+
else:
|
117 |
+
args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
|
118 |
+
args_dict['model_name'] = 'domain_' + args_dict['domain']
|
119 |
+
args_dict['full_model_name'] = path.join(args_dict['model_path'],args_dict['model_name'])
|
120 |
+
args_dict['load_path'] = args.load_path
|
121 |
+
args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
|
122 |
+
if args_dict['load_sequence_taggers_paths'] is not None:
|
123 |
+
args_dict['gating'] = True
|
124 |
+
args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
|
125 |
+
else:
|
126 |
+
if not args_dict['gating']:
|
127 |
+
args_dict['num_gates'] = 0
|
128 |
+
args_dict['strict'] = args.strict
|
129 |
+
args_dict['num_epochs'] = args.num_epochs
|
130 |
+
args_dict['batch_size'] = args.batch_size
|
131 |
+
args_dict['hidden_size'] = args.hidden_size
|
132 |
+
args_dict['arc_space'] = args.arc_space
|
133 |
+
args_dict['arc_tag_space'] = args.arc_tag_space
|
134 |
+
args_dict['num_layers'] = args.num_layers
|
135 |
+
args_dict['num_filters'] = args.num_filters
|
136 |
+
args_dict['kernel_size'] = args.kernel_size
|
137 |
+
args_dict['learning_rate'] = args.learning_rate
|
138 |
+
args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
|
139 |
+
args_dict['opt'] = args.opt
|
140 |
+
args_dict['momentum'] = args.momentum
|
141 |
+
args_dict['betas'] = tuple(args.betas)
|
142 |
+
args_dict['epsilon'] = args.epsilon
|
143 |
+
args_dict['decay_rate'] = args.decay_rate
|
144 |
+
args_dict['clip'] = args.clip
|
145 |
+
args_dict['gamma'] = args.gamma
|
146 |
+
args_dict['schedule'] = args.schedule
|
147 |
+
args_dict['p_rnn'] = tuple(args.p_rnn)
|
148 |
+
args_dict['p_in'] = args.p_in
|
149 |
+
args_dict['p_out'] = args.p_out
|
150 |
+
args_dict['unk_replace'] = args.unk_replace
|
151 |
+
args_dict['set_num_training_samples'] = args.set_num_training_samples
|
152 |
+
args_dict['punct_set'] = None
|
153 |
+
if args.punct_set is not None:
|
154 |
+
args_dict['punct_set'] = set(args.punct_set)
|
155 |
+
logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
|
156 |
+
args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
|
157 |
+
args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
|
158 |
+
args_dict['word_embedding'] = args.word_embedding
|
159 |
+
args_dict['word_path'] = args.word_path
|
160 |
+
args_dict['use_char'] = args.use_char
|
161 |
+
args_dict['char_embedding'] = args.char_embedding
|
162 |
+
args_dict['char_path'] = args.char_path
|
163 |
+
args_dict['pos_embedding'] = args.pos_embedding
|
164 |
+
args_dict['pos_path'] = args.pos_path
|
165 |
+
args_dict['use_pos'] = args.use_pos
|
166 |
+
args_dict['pos_dim'] = args.pos_dim
|
167 |
+
args_dict['word_dict'] = None
|
168 |
+
args_dict['word_dim'] = args.word_dim
|
169 |
+
if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
|
170 |
+
args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
|
171 |
+
args_dict['word_path'])
|
172 |
+
args_dict['char_dict'] = None
|
173 |
+
args_dict['char_dim'] = args.char_dim
|
174 |
+
if args_dict['char_embedding'] != 'random':
|
175 |
+
args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
|
176 |
+
args_dict['char_path'])
|
177 |
+
args_dict['pos_dict'] = None
|
178 |
+
if args_dict['pos_embedding'] != 'random':
|
179 |
+
args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
|
180 |
+
args_dict['pos_path'])
|
181 |
+
args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
|
182 |
+
args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
|
183 |
+
args_dict['eval_mode'] = args.eval_mode
|
184 |
+
args_dict['eval_with_CI'] = args.eval_with_CI
|
185 |
+
args_dict['LCM_Path_flag'] = args.LCM_Path_flag
|
186 |
+
args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
187 |
+
args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
|
188 |
+
args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
|
189 |
+
args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
|
190 |
+
logger.info("Saving arguments to file")
|
191 |
+
save_args(args, args_dict['full_model_name'])
|
192 |
+
logger.info("Creating Alphabets")
|
193 |
+
alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
|
194 |
+
args_dict = {**args_dict, **alphabet_dict}
|
195 |
+
ARGS = namedtuple('ARGS', args_dict.keys())
|
196 |
+
my_args = ARGS(**args_dict)
|
197 |
+
return my_args
|
198 |
+
|
199 |
+
|
200 |
+
def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
|
201 |
+
train_paths = alphabet_data_paths['train']
|
202 |
+
extra_paths = [v for k,v in alphabet_data_paths.items() if k != 'train']
|
203 |
+
alphabet_dict = {}
|
204 |
+
alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
|
205 |
+
train_paths,
|
206 |
+
extra_paths=extra_paths,
|
207 |
+
max_vocabulary_size=100000,
|
208 |
+
embedd_dict=word_dict)
|
209 |
+
for k, v in alphabet_dict['alphabets'].items():
|
210 |
+
num_key = 'num_' + k.split('_')[0]
|
211 |
+
alphabet_dict[num_key] = v.size()
|
212 |
+
logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
|
213 |
+
return alphabet_dict
|
214 |
+
|
215 |
+
def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
|
216 |
+
if tokens_dict is None:
|
217 |
+
return None
|
218 |
+
scale = np.sqrt(3.0 / dim)
|
219 |
+
table = np.empty([alphabet.size(), dim], dtype=np.float32)
|
220 |
+
table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
221 |
+
oov_tokens = 0
|
222 |
+
for token, index in alphabet.items():
|
223 |
+
if token in tokens_dict:
|
224 |
+
embedding = tokens_dict[token]
|
225 |
+
elif token.lower() in tokens_dict:
|
226 |
+
embedding = tokens_dict[token.lower()]
|
227 |
+
else:
|
228 |
+
embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
|
229 |
+
oov_tokens += 1
|
230 |
+
table[index, :] = embedding
|
231 |
+
print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
|
232 |
+
table = torch.from_numpy(table)
|
233 |
+
return table
|
234 |
+
|
235 |
+
def save_args(args, full_model_name):
|
236 |
+
arg_path = full_model_name + '.arg.json'
|
237 |
+
argparse_dict = vars(args)
|
238 |
+
with open(arg_path, 'w') as f:
|
239 |
+
json.dump(argparse_dict, f)
|
240 |
+
|
241 |
+
def generate_optimizer(args, lr, params):
|
242 |
+
params = filter(lambda param: param.requires_grad, params)
|
243 |
+
if args.opt == 'adam':
|
244 |
+
return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
|
245 |
+
elif args.opt == 'sgd':
|
246 |
+
return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
|
247 |
+
else:
|
248 |
+
raise ValueError('Unknown optimization algorithm: %s' % args.opt)
|
249 |
+
|
250 |
+
|
251 |
+
def save_checkpoint(args, model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
|
252 |
+
path_name = full_model_name + '.pt'
|
253 |
+
print('Saving model to: %s' % path_name)
|
254 |
+
state = {'model_state_dict': model.state_dict(),
|
255 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
256 |
+
'opt': opt,
|
257 |
+
'dev_eval_dict': dev_eval_dict,
|
258 |
+
'test_eval_dict': test_eval_dict}
|
259 |
+
torch.save(state, path_name)
|
260 |
+
|
261 |
+
def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
|
262 |
+
print('Loading saved model from: %s' % load_path)
|
263 |
+
checkpoint = torch.load(load_path, map_location=args.device)
|
264 |
+
if checkpoint['opt'] != args.opt:
|
265 |
+
raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
|
266 |
+
model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
|
267 |
+
|
268 |
+
|
269 |
+
if strict:
|
270 |
+
generate_optimizer(args, args.learning_rate, model.parameters())
|
271 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
272 |
+
for state in optimizer.state.values():
|
273 |
+
for k, v in state.items():
|
274 |
+
if isinstance(v, torch.Tensor):
|
275 |
+
state[k] = v.to(args.device)
|
276 |
+
dev_eval_dict = checkpoint['dev_eval_dict']
|
277 |
+
test_eval_dict = checkpoint['test_eval_dict']
|
278 |
+
start_epoch = dev_eval_dict['in_domain']['epoch']
|
279 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
280 |
+
|
281 |
+
|
282 |
+
def build_model_and_optimizer(args):
|
283 |
+
word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
|
284 |
+
char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
|
285 |
+
pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
|
286 |
+
model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
|
287 |
+
args.use_pos, args.use_char, args.pos_dim, args.num_pos,
|
288 |
+
args.num_filters, args.kernel_size, args.rnn_mode,
|
289 |
+
args.hidden_size, args.num_layers, args.num_arc,
|
290 |
+
args.arc_space, args.arc_tag_space, args.num_gates,
|
291 |
+
embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
|
292 |
+
p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
|
293 |
+
biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
|
294 |
+
print(model)
|
295 |
+
optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
|
296 |
+
start_epoch = 0
|
297 |
+
dev_eval_dict = {'in_domain': initialize_eval_dict()}
|
298 |
+
test_eval_dict = {'in_domain': initialize_eval_dict()}
|
299 |
+
if args.load_path:
|
300 |
+
model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
|
301 |
+
load_checkpoint(args, model, optimizer,
|
302 |
+
dev_eval_dict, test_eval_dict,
|
303 |
+
start_epoch, args.load_path, strict=args.strict)
|
304 |
+
if args.load_sequence_taggers_paths:
|
305 |
+
pretrained_dict = {}
|
306 |
+
model_dict = model.state_dict()
|
307 |
+
for idx, path in enumerate(args.load_sequence_taggers_paths):
|
308 |
+
print('Loading saved sequence_tagger from: %s' % path)
|
309 |
+
checkpoint = torch.load(path, map_location=args.device)
|
310 |
+
for k, v in checkpoint['model_state_dict'].items():
|
311 |
+
if 'rnn_encoder.' in k:
|
312 |
+
pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
|
313 |
+
model_dict.update(pretrained_dict)
|
314 |
+
model.load_state_dict(model_dict)
|
315 |
+
if args.freeze_sequence_taggers:
|
316 |
+
print('Freezing Classifiers')
|
317 |
+
for name, parameter in model.named_parameters():
|
318 |
+
if 'extra_rnn_encoders' in name:
|
319 |
+
parameter.requires_grad = False
|
320 |
+
if args.freeze_word_embeddings:
|
321 |
+
model.rnn_encoder.word_embedd.weight.requires_grad = False
|
322 |
+
# model.rnn_encoder.char_embedd.weight.requires_grad = False
|
323 |
+
# model.rnn_encoder.pos_embedd.weight.requires_grad = False
|
324 |
+
device = args.device
|
325 |
+
model.to(device)
|
326 |
+
return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
|
327 |
+
|
328 |
+
|
329 |
+
def initialize_eval_dict():
|
330 |
+
eval_dict = {}
|
331 |
+
eval_dict['dp_uas'] = 0.0
|
332 |
+
eval_dict['dp_las'] = 0.0
|
333 |
+
eval_dict['epoch'] = 0
|
334 |
+
eval_dict['dp_ucorrect'] = 0.0
|
335 |
+
eval_dict['dp_lcorrect'] = 0.0
|
336 |
+
eval_dict['dp_total'] = 0.0
|
337 |
+
eval_dict['dp_ucomplete_match'] = 0.0
|
338 |
+
eval_dict['dp_lcomplete_match'] = 0.0
|
339 |
+
eval_dict['dp_ucorrect_nopunc'] = 0.0
|
340 |
+
eval_dict['dp_lcorrect_nopunc'] = 0.0
|
341 |
+
eval_dict['dp_total_nopunc'] = 0.0
|
342 |
+
eval_dict['dp_ucomplete_match_nopunc'] = 0.0
|
343 |
+
eval_dict['dp_lcomplete_match_nopunc'] = 0.0
|
344 |
+
eval_dict['dp_root_correct'] = 0.0
|
345 |
+
eval_dict['dp_total_root'] = 0.0
|
346 |
+
eval_dict['dp_total_inst'] = 0.0
|
347 |
+
eval_dict['dp_total'] = 0.0
|
348 |
+
eval_dict['dp_total_inst'] = 0.0
|
349 |
+
eval_dict['dp_total_nopunc'] = 0.0
|
350 |
+
eval_dict['dp_total_root'] = 0.0
|
351 |
+
return eval_dict
|
352 |
+
|
353 |
+
def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
|
354 |
+
best_model, best_optimizer, patient):
|
355 |
+
# In-domain evaluation
|
356 |
+
curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
|
357 |
+
is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
|
358 |
+
(dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
|
359 |
+
dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
|
360 |
+
|
361 |
+
if is_best_in_domain:
|
362 |
+
for key, value in curr_dev_eval_dict.items():
|
363 |
+
dev_eval_dict['in_domain'][key] = value
|
364 |
+
curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
|
365 |
+
for key, value in curr_test_eval_dict.items():
|
366 |
+
test_eval_dict['in_domain'][key] = value
|
367 |
+
best_model = deepcopy(model)
|
368 |
+
best_optimizer = deepcopy(optimizer)
|
369 |
+
patient = 0
|
370 |
+
else:
|
371 |
+
patient += 1
|
372 |
+
if epoch == args.num_epochs:
|
373 |
+
# save in-domain checkpoint
|
374 |
+
if args.set_num_training_samples is not None:
|
375 |
+
splits_to_write = datasets.keys()
|
376 |
+
else:
|
377 |
+
splits_to_write = ['dev', 'test']
|
378 |
+
for split in splits_to_write:
|
379 |
+
if split == 'dev':
|
380 |
+
eval_dict = dev_eval_dict['in_domain']
|
381 |
+
elif split == 'test':
|
382 |
+
eval_dict = test_eval_dict['in_domain']
|
383 |
+
else:
|
384 |
+
eval_dict = None
|
385 |
+
write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
|
386 |
+
print("Saving best model")
|
387 |
+
save_checkpoint(args, best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
|
388 |
+
|
389 |
+
print('\n')
|
390 |
+
return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
|
391 |
+
|
392 |
+
|
393 |
+
def evaluation(args, data, split, model, domain, epoch, str_res='results'):
|
394 |
+
# evaluate performance on data
|
395 |
+
model.eval()
|
396 |
+
|
397 |
+
eval_dict = initialize_eval_dict()
|
398 |
+
eval_dict['epoch'] = epoch
|
399 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
400 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
401 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
402 |
+
heads_pred, arc_tags_pred, _ = model.decode(args.model_path, word, pos,ner, out_arc, out_arc_tag, mask=masks, length=lengths,
|
403 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
404 |
+
lengths = lengths.cpu().numpy()
|
405 |
+
word = word.data.cpu().numpy()
|
406 |
+
pos = pos.data.cpu().numpy()
|
407 |
+
ner = ner.data.cpu().numpy()
|
408 |
+
heads = heads.data.cpu().numpy()
|
409 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
410 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
411 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
412 |
+
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
|
413 |
+
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
|
414 |
+
lengths, punct_set=args.punct_set, symbolic_root=True)
|
415 |
+
ucorr, lcorr, total, ucm, lcm = stats
|
416 |
+
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
|
417 |
+
corr_root, total_root = stats_root
|
418 |
+
eval_dict['dp_ucorrect'] += ucorr
|
419 |
+
eval_dict['dp_lcorrect'] += lcorr
|
420 |
+
eval_dict['dp_total'] += total
|
421 |
+
eval_dict['dp_ucomplete_match'] += ucm
|
422 |
+
eval_dict['dp_lcomplete_match'] += lcm
|
423 |
+
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
|
424 |
+
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
|
425 |
+
eval_dict['dp_total_nopunc'] += total_nopunc
|
426 |
+
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
|
427 |
+
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
|
428 |
+
eval_dict['dp_root_correct'] += corr_root
|
429 |
+
eval_dict['dp_total_root'] += total_root
|
430 |
+
eval_dict['dp_total_inst'] += num_inst
|
431 |
+
|
432 |
+
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
433 |
+
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
434 |
+
print_results(eval_dict, split, domain, str_res)
|
435 |
+
return eval_dict
|
436 |
+
|
437 |
+
def constrained_evaluation(args, data, split, model, domain, epoch, str_res='results'):
|
438 |
+
# evaluate performance on data
|
439 |
+
model.eval()
|
440 |
+
|
441 |
+
eval_dict = initialize_eval_dict()
|
442 |
+
eval_dict['epoch'] = epoch
|
443 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
444 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
445 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
446 |
+
heads_pred, arc_tags_pred, _ = model.constrained_decode(args, word, pos,ner, out_arc, out_arc_tag, mask=masks, length=lengths,
|
447 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
448 |
+
lengths = lengths.cpu().numpy()
|
449 |
+
word = word.data.cpu().numpy()
|
450 |
+
pos = pos.data.cpu().numpy()
|
451 |
+
ner = ner.data.cpu().numpy()
|
452 |
+
heads = heads.data.cpu().numpy()
|
453 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
454 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
455 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
456 |
+
stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
|
457 |
+
arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
|
458 |
+
lengths, punct_set=args.punct_set, symbolic_root=True)
|
459 |
+
ucorr, lcorr, total, ucm, lcm = stats
|
460 |
+
ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
|
461 |
+
corr_root, total_root = stats_root
|
462 |
+
eval_dict['dp_ucorrect'] += ucorr
|
463 |
+
eval_dict['dp_lcorrect'] += lcorr
|
464 |
+
eval_dict['dp_total'] += total
|
465 |
+
eval_dict['dp_ucomplete_match'] += ucm
|
466 |
+
eval_dict['dp_lcomplete_match'] += lcm
|
467 |
+
eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
|
468 |
+
eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
|
469 |
+
eval_dict['dp_total_nopunc'] += total_nopunc
|
470 |
+
eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
|
471 |
+
eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
|
472 |
+
eval_dict['dp_root_correct'] += corr_root
|
473 |
+
eval_dict['dp_total_root'] += total_root
|
474 |
+
eval_dict['dp_total_inst'] += num_inst
|
475 |
+
|
476 |
+
eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
477 |
+
eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'] # considering w. punctuation
|
478 |
+
print_results(eval_dict, split, domain, str_res)
|
479 |
+
return eval_dict
|
480 |
+
|
481 |
+
def print_results(eval_dict, split, domain, str_res='results'):
|
482 |
+
print('----------------------------------------------------------------------------------------------------------------------------')
|
483 |
+
print('Testing model on domain %s' % domain)
|
484 |
+
print('--------------- Dependency Parsing - %s ---------------' % split)
|
485 |
+
print(
|
486 |
+
str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
|
487 |
+
eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
|
488 |
+
eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
|
489 |
+
eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
|
490 |
+
eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
|
491 |
+
eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
|
492 |
+
eval_dict['epoch']))
|
493 |
+
print(
|
494 |
+
str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
|
495 |
+
eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
|
496 |
+
eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
|
497 |
+
eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
|
498 |
+
eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
|
499 |
+
eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
|
500 |
+
eval_dict['epoch']))
|
501 |
+
print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
|
502 |
+
eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
|
503 |
+
eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
|
504 |
+
print('\n')
|
505 |
+
def constrained_write_results(args, data, data_domain, split, model, model_domain, eval_dict):
|
506 |
+
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
|
507 |
+
res_filename = str_file + '_res.txt'
|
508 |
+
pred_filename = str_file + '_pred.txt'
|
509 |
+
gold_filename = str_file + '_gold.txt'
|
510 |
+
if eval_dict is not None:
|
511 |
+
# save results dictionary into a file
|
512 |
+
with open(res_filename, 'w') as f:
|
513 |
+
json.dump(eval_dict, f)
|
514 |
+
|
515 |
+
# save predictions and gold labels into files
|
516 |
+
pred_writer = Writer(args.alphabets)
|
517 |
+
gold_writer = Writer(args.alphabets)
|
518 |
+
pred_writer.start(pred_filename)
|
519 |
+
gold_writer.start(gold_filename)
|
520 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
521 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
522 |
+
# pdb.set_trace()
|
523 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
524 |
+
heads_pred, arc_tags_pred, _ = model.constrained_decode(args,word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
|
525 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
526 |
+
lengths = lengths.cpu().numpy()
|
527 |
+
word = word.data.cpu().numpy()
|
528 |
+
pos = pos.data.cpu().numpy()
|
529 |
+
ner = ner.data.cpu().numpy()
|
530 |
+
heads = heads.data.cpu().numpy()
|
531 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
532 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
533 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
534 |
+
# print('words',word)
|
535 |
+
# print('Pos',pos)
|
536 |
+
|
537 |
+
# print('heads_pred',heads_pred)
|
538 |
+
# print('arc_tags_pred',arc_tags_pred)
|
539 |
+
# pdb.set_trace()
|
540 |
+
# writing predictions
|
541 |
+
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
|
542 |
+
# writing gold labels
|
543 |
+
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
|
544 |
+
|
545 |
+
pred_writer.close()
|
546 |
+
gold_writer.close()
|
547 |
+
def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
|
548 |
+
str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
|
549 |
+
res_filename = str_file + '_res.txt'
|
550 |
+
pred_filename = str_file + '_pred.txt'
|
551 |
+
gold_filename = str_file + '_gold.txt'
|
552 |
+
if eval_dict is not None:
|
553 |
+
# save results dictionary into a file
|
554 |
+
with open(res_filename, 'w') as f:
|
555 |
+
json.dump(eval_dict, f)
|
556 |
+
|
557 |
+
# save predictions and gold labels into files
|
558 |
+
pred_writer = Writer(args.alphabets)
|
559 |
+
gold_writer = Writer(args.alphabets)
|
560 |
+
pred_writer.start(pred_filename)
|
561 |
+
gold_writer.start(gold_filename)
|
562 |
+
for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
|
563 |
+
word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
|
564 |
+
# pdb.set_trace()
|
565 |
+
out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
|
566 |
+
heads_pred, arc_tags_pred, _ = model.decode(args.model_path,word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
|
567 |
+
leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
|
568 |
+
lengths = lengths.cpu().numpy()
|
569 |
+
word = word.data.cpu().numpy()
|
570 |
+
pos = pos.data.cpu().numpy()
|
571 |
+
ner = ner.data.cpu().numpy()
|
572 |
+
heads = heads.data.cpu().numpy()
|
573 |
+
arc_tags = arc_tags.data.cpu().numpy()
|
574 |
+
heads_pred = heads_pred.data.cpu().numpy()
|
575 |
+
arc_tags_pred = arc_tags_pred.data.cpu().numpy()
|
576 |
+
# print('words',word)
|
577 |
+
# print('Pos',pos)
|
578 |
+
|
579 |
+
# print('heads_pred',heads_pred)
|
580 |
+
# print('arc_tags_pred',arc_tags_pred)
|
581 |
+
# pdb.set_trace()
|
582 |
+
# writing predictions
|
583 |
+
pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
|
584 |
+
# writing gold labels
|
585 |
+
gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
|
586 |
+
|
587 |
+
pred_writer.close()
|
588 |
+
gold_writer.close()
|
589 |
+
|
590 |
+
+def main():
+    logger.info("Reading and creating arguments")
+    args = read_arguments()
+    logger.info("Reading Data")
+    datasets = {}
+    for split in args.splits:
+        print("Splits are:", split)
+        dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
+                                                     symbolic_root=True)
+        datasets[split] = dataset
+    if args.set_num_training_samples is not None:
+        print('Setting train and dev to %d samples' % args.set_num_training_samples)
+        datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
+    logger.info("Creating Networks")
+    num_data = sum(datasets['train'][1])
+    model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
+    best_model = deepcopy(model)
+    best_optimizer = deepcopy(optimizer)
+
+    logger.info('Training INFO of in domain %s' % args.domain)
+    logger.info('Training on Dependency Parsing')
+    logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
+    logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
+    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
+    logger.info("num_epochs: %d" % (args.num_epochs))
+    print('\n')
+
+    if not args.eval_mode:
+        logger.info("Training")
+        num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
+        lr = args.learning_rate
+        patient = 0
+        decay = 0
+        for epoch in range(start_epoch + 1, args.num_epochs + 1):
+            print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
+                epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
+            model.train()
+            total_loss = 0.0
+            total_arc_loss = 0.0
+            total_arc_tag_loss = 0.0
+            total_train_inst = 0.0
+
+            train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
+                datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
+            start_time = time.time()
+            batch_num = 0
+            for batch_num, batch in enumerate(train_iter):
+                batch_num = batch_num + 1
+                optimizer.zero_grad()
+                # compute loss of main task
+                word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
+                out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
+                loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
+                loss = loss_arc + loss_arc_tag
+                # pdb.set_trace()
+
+                # update losses
+                num_insts = masks.data.sum() - word.size(0)
+                total_arc_loss += loss_arc.item() * num_insts
+                total_arc_tag_loss += loss_arc_tag.item() * num_insts
+                total_loss += loss.item() * num_insts
+                total_train_inst += num_insts
+                # optimize parameters
+                loss.backward()
+                clip_grad_norm_(model.parameters(), args.clip)
+                optimizer.step()
+
+                time_ave = (time.time() - start_time) / batch_num
+                time_left = (num_batches - batch_num) * time_ave
+
+                # update log
+                if batch_num % 50 == 0:
+                    log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
+                               (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
+                                total_arc_tag_loss / total_train_inst, time_left)
+                    sys.stdout.write(log_info)
+                    sys.stdout.write('\n')
+                    sys.stdout.flush()
+            print('\n')
+            print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
+                  (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
+                   total_arc_tag_loss / total_train_inst, time.time() - start_time))
+
+            dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
+            if patient >= args.schedule:
+                lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
+                optimizer = generate_optimizer(args, lr, model.parameters())
+                print('updated learning rate to %.6f' % lr)
+                patient = 0
+        print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
+        print('\n')
+        for split in datasets.keys():
+            if args.eval_with_CI and split not in ['train', 'extra_train', 'extra_dev']:
+                print('Currently going on ... ', split)
+                eval_dict = constrained_evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
+                constrained_write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
+            else:
+                eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
+                write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
+
+    else:
+        logger.info("Evaluating")
+        epoch = start_epoch
+        for split in ['train', 'dev', 'test']:
+            if args.eval_with_CI and split != 'train':
+                eval_dict = constrained_evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
+                constrained_write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
+            else:
+                eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
+                write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
+
+
+if __name__ == '__main__':
+    main()
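Note on the schedule in the training loop above: the learning rate is only recomputed once dev performance has stalled for schedule consecutive epochs (patient >= args.schedule), and the optimizer is then rebuilt at the new rate. A minimal standalone sketch of the decay formula (the function name decayed_lr is ours, not the repository's):

    def decayed_lr(lr0, decay_rate, epoch):
        # same formula as: lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
        return lr0 / (1.0 + epoch * decay_rate)

    # with the script defaults lr0=0.01, decay_rate=0.05:
    # decayed_lr(0.01, 0.05, 20) -> 0.005 (half the initial rate by epoch 20)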
examples/GraphParser_MRL.py
ADDED
@@ -0,0 +1,603 @@
+from __future__ import print_function
+import sys
+from os import path, makedirs
+
+sys.path.append(".")
+sys.path.append("..")
+
+import argparse
+from copy import deepcopy
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+from collections import namedtuple
+from utils.io_ import seeds, Writer, get_logger, prepare_data, rearrange_splits
+from utils.models.parsing_gating import BiAffine_Parser_Gated
+from utils import load_word_embeddings
+from utils.tasks import parse
+import time
+from torch.nn.utils import clip_grad_norm_
+from torch.optim import Adam, SGD
+import uuid
+
+uid = uuid.uuid4().hex[:6]
+
+logger = get_logger('GraphParser')
+
+def read_arguments():
+    args_ = argparse.ArgumentParser(description='Solving GraphParser')
+    args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
+    args_.add_argument('--domain', help='domain/language', required=True)
+    args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
+                       required=True)
+    args_.add_argument('--gating', action='store_true', help='use gated mechanism')
+    args_.add_argument('--num_gates', type=int, default=0, help='number of gates for gating mechanism')
+    args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
+    args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
+    args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
+    args_.add_argument('--arc_space', type=int, default=128, help='Dimension of arc space')
+    args_.add_argument('--arc_tag_space', type=int, default=128, help='Dimension of tag space')
+    args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
+    args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
+    args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
+    args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
+    args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
+    args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
+    args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
+    args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
+    args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
+    args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
+    args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
+    args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
+    args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
+    args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
+    args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
+    args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
+    args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
+    args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
+    args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
+    args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
+    args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
+    args_.add_argument('--arc_decode', choices=['mst', 'greedy'], help='arc decoding algorithm', required=True)
+    args_.add_argument('--unk_replace', type=float, default=0.,
+                       help='The rate to replace a singleton word with UNK')
+    args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
+    args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
+                       help='Embedding for words')
+    args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
+    args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embedding (disable fine-tuning).')
+    args_.add_argument('--freeze_sequence_taggers', action='store_true', help='freeze the BiLSTMs of the pre-trained taggers.')
+    args_.add_argument('--char_embedding', choices=['random', 'hellwig'], help='Embedding for characters',
+                       required=True)
+    args_.add_argument('--pos_embedding', choices=['random', 'one_hot'], help='Embedding for pos',
+                       required=True)
+    args_.add_argument('--char_path', help='path for character embedding dict')
+    args_.add_argument('--pos_path', help='path for pos embedding dict')
+    args_.add_argument('--set_num_training_samples', type=int, help='downsampling training set to a fixed number of samples')
+    args_.add_argument('--model_path', help='path for saving model file.', required=True)
+    args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
+    args_.add_argument('--load_sequence_taggers_paths', nargs='+', help='path for loading saved sequence_tagger saved_models files.', default=None)
+    args_.add_argument('--strict', action='store_true', help='if True loaded model state should contain '
+                       'exactly the same keys as current model')
+    args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
+    args = args_.parse_args()
+    args_dict = {}
+    args_dict['dataset'] = args.dataset
+    args_dict['domain'] = args.domain
+    args_dict['rnn_mode'] = args.rnn_mode
+    args_dict['gating'] = args.gating
+    args_dict['num_gates'] = args.num_gates
+    args_dict['arc_decode'] = args.arc_decode
+    # args_dict['splits'] = ['train', 'dev', 'test']
+    args_dict['splits'] = ['train', 'dev', 'test', 'poetry', 'prose']
+    args_dict['model_path'] = args.model_path
+    if not path.exists(args_dict['model_path']):
+        makedirs(args_dict['model_path'])
+    args_dict['data_paths'] = {}
+    if args_dict['dataset'] == 'ontonotes':
+        data_path = 'data/Pre_MRL/onto_pos_ner_dp'
+    else:
+        data_path = 'data/Prep_MRL/ud_pos_ner_dp'
+    for split in args_dict['splits']:
+        args_dict['data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain']
+    ###################################
+    args_dict['data_paths']['poetry'] = data_path + '_' + 'test' + '_' + args_dict['domain']
+    args_dict['data_paths']['prose'] = data_path + '_' + 'test' + '_' + args_dict['domain']
+    ###################################
+    args_dict['alphabet_data_paths'] = {}
+    for split in args_dict['splits']:
+        if args_dict['dataset'] == 'ontonotes':
+            args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + 'all'
+        else:
+            if '_' in args_dict['domain']:
+                args_dict['alphabet_data_paths'][split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
+            else:
+                args_dict['alphabet_data_paths'][split] = args_dict['data_paths'][split]
+    args_dict['model_name'] = 'domain_' + args_dict['domain']
+    args_dict['full_model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
+    args_dict['load_path'] = args.load_path
+    args_dict['load_sequence_taggers_paths'] = args.load_sequence_taggers_paths
+    if args_dict['load_sequence_taggers_paths'] is not None:
+        args_dict['gating'] = True
+        args_dict['num_gates'] = len(args_dict['load_sequence_taggers_paths']) + 1
+    else:
+        if not args_dict['gating']:
+            args_dict['num_gates'] = 0
+    args_dict['strict'] = args.strict
+    args_dict['num_epochs'] = args.num_epochs
+    args_dict['batch_size'] = args.batch_size
+    args_dict['hidden_size'] = args.hidden_size
+    args_dict['arc_space'] = args.arc_space
+    args_dict['arc_tag_space'] = args.arc_tag_space
+    args_dict['num_layers'] = args.num_layers
+    args_dict['num_filters'] = args.num_filters
+    args_dict['kernel_size'] = args.kernel_size
+    args_dict['learning_rate'] = args.learning_rate
+    args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
+    args_dict['opt'] = args.opt
+    args_dict['momentum'] = args.momentum
+    args_dict['betas'] = tuple(args.betas)
+    args_dict['epsilon'] = args.epsilon
+    args_dict['decay_rate'] = args.decay_rate
+    args_dict['clip'] = args.clip
+    args_dict['gamma'] = args.gamma
+    args_dict['schedule'] = args.schedule
+    args_dict['p_rnn'] = tuple(args.p_rnn)
+    args_dict['p_in'] = args.p_in
+    args_dict['p_out'] = args.p_out
+    args_dict['unk_replace'] = args.unk_replace
+    args_dict['set_num_training_samples'] = args.set_num_training_samples
+    args_dict['punct_set'] = None
+    if args.punct_set is not None:
+        args_dict['punct_set'] = set(args.punct_set)
+        logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
+    args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
+    args_dict['freeze_sequence_taggers'] = args.freeze_sequence_taggers
+    args_dict['word_embedding'] = args.word_embedding
+    args_dict['word_path'] = args.word_path
+    args_dict['use_char'] = args.use_char
+    args_dict['char_embedding'] = args.char_embedding
+    args_dict['char_path'] = args.char_path
+    args_dict['pos_embedding'] = args.pos_embedding
+    args_dict['pos_path'] = args.pos_path
+    args_dict['use_pos'] = args.use_pos
+    args_dict['pos_dim'] = args.pos_dim
+    args_dict['word_dict'] = None
+    args_dict['word_dim'] = args.word_dim
+    if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
+        args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(args_dict['word_embedding'],
+                                                                                                 args_dict['word_path'])
+    args_dict['char_dict'] = None
+    args_dict['char_dim'] = args.char_dim
+    if args_dict['char_embedding'] != 'random':
+        args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(args_dict['char_embedding'],
+                                                                                                 args_dict['char_path'])
+    args_dict['pos_dict'] = None
+    if args_dict['pos_embedding'] != 'random':
+        args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(args_dict['pos_embedding'],
+                                                                                               args_dict['pos_path'])
+    args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
+    args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
+    args_dict['eval_mode'] = args.eval_mode
+    args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
+    args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
+    args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
+    logger.info("Saving arguments to file")
+    save_args(args, args_dict['full_model_name'])
+    logger.info("Creating Alphabets")
+    alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_data_paths'], args_dict['word_dict'])
+    args_dict = {**args_dict, **alphabet_dict}
+    ARGS = namedtuple('ARGS', args_dict.keys())
+    my_args = ARGS(**args_dict)
+    return my_args
+
+
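The data paths above are pure string concatenations of dataset prefix, split, and domain. An illustrative expansion (dataset 'ud' and domain 'san' are example values):

    # data_path = 'data/Prep_MRL/ud_pos_ner_dp', split = 'train', domain = 'san' gives:
    # 'data/Prep_MRL/ud_pos_ner_dp_train_san'
    # the extra 'poetry' and 'prose' splits are both aliased to the test file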
+def creating_alphabets(alphabet_path, alphabet_data_paths, word_dict):
+    train_paths = alphabet_data_paths['train']
+    extra_paths = [v for k, v in alphabet_data_paths.items() if k != 'train']
+    alphabet_dict = {}
+    alphabet_dict['alphabets'] = prepare_data.create_alphabets(alphabet_path,
+                                                               train_paths,
+                                                               extra_paths=extra_paths,
+                                                               max_vocabulary_size=100000,
+                                                               embedd_dict=word_dict)
+    for k, v in alphabet_dict['alphabets'].items():
+        num_key = 'num_' + k.split('_')[0]
+        alphabet_dict[num_key] = v.size()
+        logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
+    return alphabet_dict
+
+def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
+    if tokens_dict is None:
+        return None
+    scale = np.sqrt(3.0 / dim)
+    table = np.empty([alphabet.size(), dim], dtype=np.float32)
+    table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
+    oov_tokens = 0
+    for token, index in alphabet.items():
+        if token in ['aTA']:
+            embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
+            oov_tokens += 1
+        elif token in tokens_dict:
+            embedding = tokens_dict[token]
+        elif token.lower() in tokens_dict:
+            embedding = tokens_dict[token.lower()]
+        else:
+            embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
+            oov_tokens += 1
+            # print(token)
+        table[index, :] = embedding
+    print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
+    table = torch.from_numpy(table)
+    return table
+
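construct_embedding_table falls back in a fixed order: pre-trained vector for the exact token, then for its lowercased form, then a fresh uniform sample counted as OOV. The same order as a minimal sketch (the helper name lookup is ours, not the repository's; scale mirrors np.sqrt(3.0 / dim) above):

    import numpy as np

    def lookup(token, tokens_dict, dim):
        scale = np.sqrt(3.0 / dim)
        if token in tokens_dict:              # exact match
            return tokens_dict[token]
        if token.lower() in tokens_dict:      # case-insensitive match
            return tokens_dict[token.lower()]
        # out of vocabulary: uniform init in [-scale, scale]
        return np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)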
+def save_args(args, full_model_name):
+    arg_path = full_model_name + '.arg.json'
+    argparse_dict = vars(args)
+    with open(arg_path, 'w') as f:
+        json.dump(argparse_dict, f)
+
+def generate_optimizer(args, lr, params):
+    params = filter(lambda param: param.requires_grad, params)
+    if args.opt == 'adam':
+        return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
+    elif args.opt == 'sgd':
+        return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
+    else:
+        raise ValueError('Unknown optimization algorithm: %s' % args.opt)
+
+
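Because generate_optimizer filters on requires_grad, parameters frozen elsewhere (for example the word embeddings under --freeze_word_embeddings) never reach the optimizer at all. The same idea in isolation (model and lr are placeholder names):

    from torch.optim import Adam

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = Adam(trainable, lr=lr)  # frozen tensors receive no updates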
+def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
+    path_name = full_model_name + '.pt'
+    print('Saving model to: %s' % path_name)
+    state = {'model_state_dict': model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'opt': opt,
+             'dev_eval_dict': dev_eval_dict,
+             'test_eval_dict': test_eval_dict}
+    torch.save(state, path_name)
+
+
+def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
+    print('Loading saved model from: %s' % load_path)
+    checkpoint = torch.load(load_path, map_location=args.device)
+    if checkpoint['opt'] != args.opt:
+        raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
+    model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
+
+    if strict:
+        generate_optimizer(args, args.learning_rate, model.parameters())
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        for state in optimizer.state.values():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor):
+                    state[k] = v.to(args.device)
+        dev_eval_dict = checkpoint['dev_eval_dict']
+        test_eval_dict = checkpoint['test_eval_dict']
+        start_epoch = dev_eval_dict['in_domain']['epoch']
+    return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
+
+
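On the relocation loop inside load_checkpoint: optimizer.load_state_dict restores the optimizer's buffers on whatever device they were serialized from, so every tensor in the optimizer state is moved onto args.device before training resumes. The idiom in isolation (opt and device are placeholder names):

    for state in opt.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)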
+def build_model_and_optimizer(args):
+    word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
+    char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
+    pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
+    model = BiAffine_Parser_Gated(args.word_dim, args.num_word, args.char_dim, args.num_char,
+                                  args.use_pos, args.use_char, args.pos_dim, args.num_pos,
+                                  args.num_filters, args.kernel_size, args.rnn_mode,
+                                  args.hidden_size, args.num_layers, args.num_arc,
+                                  args.arc_space, args.arc_tag_space, args.num_gates,
+                                  embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
+                                  p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
+                                  biaffine=True, arc_decode=args.arc_decode, initializer=args.initializer)
+    print(model)
+    optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
+    start_epoch = 0
+    dev_eval_dict = {'in_domain': initialize_eval_dict()}
+    test_eval_dict = {'in_domain': initialize_eval_dict()}
+    if args.load_path:
+        model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
+            load_checkpoint(args, model, optimizer,
+                            dev_eval_dict, test_eval_dict,
+                            start_epoch, args.load_path, strict=args.strict)
+    if args.load_sequence_taggers_paths:
+        pretrained_dict = {}
+        model_dict = model.state_dict()
+        for idx, path in enumerate(args.load_sequence_taggers_paths):
+            print('Loading saved sequence_tagger from: %s' % path)
+            checkpoint = torch.load(path, map_location=args.device)
+            for k, v in checkpoint['model_state_dict'].items():
+                if 'rnn_encoder.' in k:
+                    pretrained_dict['extra_rnn_encoders.' + str(idx) + '.' + k.replace('rnn_encoder.', '')] = v
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+        if args.freeze_sequence_taggers:
+            print('Freezing Classifiers')
+            for name, parameter in model.named_parameters():
+                if 'extra_rnn_encoders' in name:
+                    parameter.requires_grad = False
+    if args.freeze_word_embeddings:
+        model.rnn_encoder.word_embedd.weight.requires_grad = False
+        # model.rnn_encoder.char_embedd.weight.requires_grad = False
+        # model.rnn_encoder.pos_embedd.weight.requires_grad = False
+    device = args.device
+    model.to(device)
+    return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch
+
+
+def initialize_eval_dict():
+    eval_dict = {}
+    eval_dict['dp_uas'] = 0.0
+    eval_dict['dp_las'] = 0.0
+    eval_dict['epoch'] = 0
+    eval_dict['dp_ucorrect'] = 0.0
+    eval_dict['dp_lcorrect'] = 0.0
+    eval_dict['dp_total'] = 0.0
+    eval_dict['dp_ucomplete_match'] = 0.0
+    eval_dict['dp_lcomplete_match'] = 0.0
+    eval_dict['dp_ucorrect_nopunc'] = 0.0
+    eval_dict['dp_lcorrect_nopunc'] = 0.0
+    eval_dict['dp_total_nopunc'] = 0.0
+    eval_dict['dp_ucomplete_match_nopunc'] = 0.0
+    eval_dict['dp_lcomplete_match_nopunc'] = 0.0
+    eval_dict['dp_root_correct'] = 0.0
+    eval_dict['dp_total_root'] = 0.0
+    eval_dict['dp_total_inst'] = 0.0
+    return eval_dict
+
+def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
+                         best_model, best_optimizer, patient):
+    # In-domain evaluation
+    curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
+    is_best_in_domain = dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] <= curr_dev_eval_dict['dp_lcorrect_nopunc'] or \
+                        (dev_eval_dict['in_domain']['dp_lcorrect_nopunc'] == curr_dev_eval_dict['dp_lcorrect_nopunc'] and
+                         dev_eval_dict['in_domain']['dp_ucorrect_nopunc'] <= curr_dev_eval_dict['dp_ucorrect_nopunc'])
+
+    if is_best_in_domain:
+        for key, value in curr_dev_eval_dict.items():
+            dev_eval_dict['in_domain'][key] = value
+        curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
+        for key, value in curr_test_eval_dict.items():
+            test_eval_dict['in_domain'][key] = value
+        best_model = deepcopy(model)
+        best_optimizer = deepcopy(optimizer)
+        patient = 0
+    else:
+        patient += 1
+    if epoch == args.num_epochs:
+        # save in-domain checkpoint
+        if args.set_num_training_samples is not None:
+            splits_to_write = datasets.keys()
+        else:
+            splits_to_write = ['dev', 'test']
+        for split in splits_to_write:
+            if split == 'dev':
+                eval_dict = dev_eval_dict['in_domain']
+            elif split == 'test':
+                eval_dict = test_eval_dict['in_domain']
+            else:
+                eval_dict = None
+            write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
+        print("Saving best model")
+        save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)
+
+    print('\n')
+    return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient
+
+
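As written, is_best_in_domain reduces to its first disjunct (the explicit tie-break clause is subsumed by the <= comparison), so an epoch that merely ties the best labeled count also replaces the stored best model. A strict lexicographic criterion, labeled correct first with unlabeled correct as tie-break, would be (is_better is our name, a sketch rather than the repository's logic):

    def is_better(curr, best):
        return (curr['dp_lcorrect_nopunc'], curr['dp_ucorrect_nopunc']) > \
               (best['dp_lcorrect_nopunc'], best['dp_ucorrect_nopunc'])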
+def evaluation(args, data, split, model, domain, epoch, str_res='results'):
+    # evaluate performance on data
+    model.eval()
+
+    eval_dict = initialize_eval_dict()
+    eval_dict['epoch'] = epoch
+    for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
+        word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
+        out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
+        heads_pred, arc_tags_pred, _ = model.decode(args.model_path, word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
+                                                    leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
+        lengths = lengths.cpu().numpy()
+        word = word.data.cpu().numpy()
+        pos = pos.data.cpu().numpy()
+        ner = ner.data.cpu().numpy()
+        heads = heads.data.cpu().numpy()
+        arc_tags = arc_tags.data.cpu().numpy()
+        heads_pred = heads_pred.data.cpu().numpy()
+        arc_tags_pred = arc_tags_pred.data.cpu().numpy()
+        stats, stats_nopunc, stats_root, num_inst = parse.eval_(word, pos, heads_pred, arc_tags_pred, heads,
+                                                                arc_tags, args.alphabets['word_alphabet'], args.alphabets['pos_alphabet'],
+                                                                lengths, punct_set=args.punct_set, symbolic_root=True)
+        ucorr, lcorr, total, ucm, lcm = stats
+        ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
+        corr_root, total_root = stats_root
+        eval_dict['dp_ucorrect'] += ucorr
+        eval_dict['dp_lcorrect'] += lcorr
+        eval_dict['dp_total'] += total
+        eval_dict['dp_ucomplete_match'] += ucm
+        eval_dict['dp_lcomplete_match'] += lcm
+        eval_dict['dp_ucorrect_nopunc'] += ucorr_nopunc
+        eval_dict['dp_lcorrect_nopunc'] += lcorr_nopunc
+        eval_dict['dp_total_nopunc'] += total_nopunc
+        eval_dict['dp_ucomplete_match_nopunc'] += ucm_nopunc
+        eval_dict['dp_lcomplete_match_nopunc'] += lcm_nopunc
+        eval_dict['dp_root_correct'] += corr_root
+        eval_dict['dp_total_root'] += total_root
+        eval_dict['dp_total_inst'] += num_inst
+
+    eval_dict['dp_uas'] = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total']  # considering w. punctuation
+    eval_dict['dp_las'] = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total']  # considering w. punctuation
+    print_results(eval_dict, split, domain, str_res)
+    return eval_dict
+
+
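For quick reference, the attachment scores stored by evaluation above are plain ratios over the accumulated counts; print_results below recomputes the same quantities:

    uas = eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total']  # unlabeled attachment score
    las = eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total']  # labeled attachment score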
+def print_results(eval_dict, split, domain, str_res='results'):
+    print('----------------------------------------------------------------------------------------------------------------------------')
+    print('Testing model on domain %s' % domain)
+    print('--------------- Dependency Parsing - %s ---------------' % split)
+    print(
+        str_res + ' on ' + split + ' W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
+            eval_dict['dp_ucorrect'], eval_dict['dp_lcorrect'], eval_dict['dp_total'],
+            eval_dict['dp_ucorrect'] * 100 / eval_dict['dp_total'],
+            eval_dict['dp_lcorrect'] * 100 / eval_dict['dp_total'],
+            eval_dict['dp_ucomplete_match'] * 100 / eval_dict['dp_total_inst'],
+            eval_dict['dp_lcomplete_match'] * 100 / eval_dict['dp_total_inst'],
+            eval_dict['epoch']))
+    print(
+        str_res + ' on ' + split + ' Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
+            eval_dict['dp_ucorrect_nopunc'], eval_dict['dp_lcorrect_nopunc'], eval_dict['dp_total_nopunc'],
+            eval_dict['dp_ucorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
+            eval_dict['dp_lcorrect_nopunc'] * 100 / eval_dict['dp_total_nopunc'],
+            eval_dict['dp_ucomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
+            eval_dict['dp_lcomplete_match_nopunc'] * 100 / eval_dict['dp_total_inst'],
+            eval_dict['epoch']))
+    print(str_res + ' on ' + split + ' Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
+        eval_dict['dp_root_correct'], eval_dict['dp_total_root'],
+        eval_dict['dp_root_correct'] * 100 / eval_dict['dp_total_root'], eval_dict['epoch']))
+    print('\n')
+
+def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
+    str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
+    res_filename = str_file + '_res.txt'
+    pred_filename = str_file + '_pred.txt'
+    gold_filename = str_file + '_gold.txt'
+    if eval_dict is not None:
+        # save results dictionary into a file
+        with open(res_filename, 'w') as f:
+            json.dump(eval_dict, f)
+
+    # save predictions and gold labels into files
+    pred_writer = Writer(args.alphabets)
+    gold_writer = Writer(args.alphabets)
+    pred_writer.start(pred_filename)
+    gold_writer.start(gold_filename)
+    for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
+        word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
+        out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
+        heads_pred, arc_tags_pred, _ = model.decode(args.model_path, word, pos, ner, out_arc, out_arc_tag, mask=masks, length=lengths,
+                                                    leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
+        lengths = lengths.cpu().numpy()
+        word = word.data.cpu().numpy()
+        pos = pos.data.cpu().numpy()
+        ner = ner.data.cpu().numpy()
+        heads = heads.data.cpu().numpy()
+        arc_tags = arc_tags.data.cpu().numpy()
+        heads_pred = heads_pred.data.cpu().numpy()
+        arc_tags_pred = arc_tags_pred.data.cpu().numpy()
+        # writing predictions
+        pred_writer.write(word, pos, ner, heads_pred, arc_tags_pred, lengths, symbolic_root=True)
+        # writing gold labels
+        gold_writer.write(word, pos, ner, heads, arc_tags, lengths, symbolic_root=True)
+
+    pred_writer.close()
+    gold_writer.close()
+
+def main():
+    logger.info("Reading and creating arguments")
+    args = read_arguments()
+    logger.info("Reading Data")
+    datasets = {}
+    for split in args.splits:
+        print("Splits are:", split)
+        dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
+                                                     symbolic_root=True)
+        datasets[split] = dataset
+    if args.set_num_training_samples is not None:
+        print('Setting train and dev to %d samples' % args.set_num_training_samples)
+        datasets = rearrange_splits.rearranging_splits(datasets, args.set_num_training_samples)
+    logger.info("Creating Networks")
+    num_data = sum(datasets['train'][1])
+    model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
+    best_model = deepcopy(model)
+    best_optimizer = deepcopy(optimizer)
+
+    logger.info('Training INFO of in domain %s' % args.domain)
+    logger.info('Training on Dependency Parsing')
+    logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
+    logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
+    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
+    logger.info("num_epochs: %d" % (args.num_epochs))
+    print('\n')
+
+    if not args.eval_mode:
+        logger.info("Training")
+        num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
+        lr = args.learning_rate
+        patient = 0
+        decay = 0
+        for epoch in range(start_epoch + 1, args.num_epochs + 1):
+            print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
+                epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
+            model.train()
+            total_loss = 0.0
+            total_arc_loss = 0.0
+            total_arc_tag_loss = 0.0
+            total_train_inst = 0.0
+
+            train_iter = prepare_data.iterate_batch_rand_bucket_choosing(
+                datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
+            start_time = time.time()
+            batch_num = 0
+            for batch_num, batch in enumerate(train_iter):
+                batch_num = batch_num + 1
+                optimizer.zero_grad()
+                # compute loss of main task
+                word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
+                out_arc, out_arc_tag, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
+                loss_arc, loss_arc_tag = model.loss(out_arc, out_arc_tag, heads, arc_tags, mask=masks, length=lengths)
+                loss = loss_arc + loss_arc_tag
+
+                # update losses
+                num_insts = masks.data.sum() - word.size(0)
+                total_arc_loss += loss_arc.item() * num_insts
+                total_arc_tag_loss += loss_arc_tag.item() * num_insts
+                total_loss += loss.item() * num_insts
+                total_train_inst += num_insts
+                # optimize parameters
+                loss.backward()
+                clip_grad_norm_(model.parameters(), args.clip)
+                optimizer.step()
+
+                time_ave = (time.time() - start_time) / batch_num
+                time_left = (num_batches - batch_num) * time_ave
+
+                # update log
+                if batch_num % 50 == 0:
+                    log_info = 'train: %d/%d, domain: %s, total loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time left: %.2fs' % \
+                               (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
+                                total_arc_tag_loss / total_train_inst, time_left)
+                    sys.stdout.write(log_info)
+                    sys.stdout.write('\n')
+                    sys.stdout.flush()
+            print('\n')
+            print('train: %d/%d, domain: %s, total_loss: %.2f, arc_loss: %.2f, arc_tag_loss: %.2f, time: %.2fs' %
+                  (batch_num, num_batches, args.domain, total_loss / total_train_inst, total_arc_loss / total_train_inst,
+                   total_arc_tag_loss / total_train_inst, time.time() - start_time))
+
+            dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient = in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model, best_optimizer, patient)
+            if patient >= args.schedule:
+                lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
+                optimizer = generate_optimizer(args, lr, model.parameters())
+                print('updated learning rate to %.6f' % lr)
+                patient = 0
+        print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
+        print('\n')
+        for split in datasets.keys():
+            eval_dict = evaluation(args, datasets[split], split, best_model, args.domain, epoch, 'best_results')
+            write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
+
+    else:
+        logger.info("Evaluating")
+        epoch = start_epoch
+        for split in ['train', 'dev', 'test', 'poetry', 'prose']:
+            eval_dict = evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')
+            write_results(args, datasets[split], args.domain, split, model, args.domain, eval_dict)
+
+
+if __name__ == '__main__':
+    main()
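A note on the loss averaging in both training loops above: masks.data.sum() counts every unmasked token position in the batch, and subtracting word.size(0) removes the one symbolic ROOT position per sentence, so the reported losses are averaged over actual dependency decisions rather than over all positions. The computation in isolation (masks and word stand for the batch tensors above):

    # unmasked positions minus one symbolic ROOT per sentence (batch size = word.size(0))
    num_insts = masks.data.sum() - word.size(0)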
examples/SequenceTagger.py
ADDED
@@ -0,0 +1,597 @@
+from __future__ import print_function
+import sys
+from os import path, makedirs, system, remove
+
+sys.path.append(".")
+sys.path.append("..")
+
+import time
+import argparse
+import uuid
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+from collections import namedtuple
+from copy import deepcopy
+from torch.nn.utils import clip_grad_norm_
+from torch.optim import Adam, SGD
+from utils.io_ import seeds, Writer, get_logger, Index2Instance, prepare_data, write_extra_labels
+from utils.models.sequence_tagger import Sequence_Tagger
+from utils import load_word_embeddings
+from utils.tasks.seqeval import accuracy_score, f1_score, precision_score, recall_score, classification_report
+
+uid = uuid.uuid4().hex[:6]
+
+logger = get_logger('SequenceTagger')
+
+def read_arguments():
+    args_ = argparse.ArgumentParser(description='Solving SequenceTagger')
+    args_.add_argument('--dataset', choices=['ontonotes', 'ud'], help='Dataset', required=True)
+    args_.add_argument('--domain', help='domain', required=True)
+    args_.add_argument('--rnn_mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn',
+                       required=True)
+    args_.add_argument('--task', default='distance_from_the_root',
+                       choices=['distance_from_the_root', 'number_of_children',
+                                'relative_pos_based', 'language_model', 'add_label', 'add_head_coarse_pos',
+                                'Multitask_POS_predict', 'Multitask_case_predict',
+                                'Multitask_label_predict', 'Multitask_coarse_predict',
+                                'MRL_Person', 'MRL_Gender', 'MRL_case', 'MRL_POS', 'MRL_no', 'MRL_label',
+                                'predict_coarse_of_modifier', 'predict_ma_tag_of_modifier',
+                                'add_head_ma', 'predict_case_of_modifier'],
+                       help='sequence_tagger task')
+    args_.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
+    args_.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
+    args_.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
+    args_.add_argument('--tag_space', type=int, default=128, help='Dimension of tag space')
+    args_.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
+    args_.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
+    args_.add_argument('--kernel_size', type=int, default=3, help='Size of Kernel for CNN')
+    args_.add_argument('--use_pos', action='store_true', help='use part-of-speech embedding.')
+    args_.add_argument('--use_char', action='store_true', help='use character embedding and CNN.')
+    args_.add_argument('--word_dim', type=int, default=300, help='Dimension of word embeddings')
+    args_.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
+    args_.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
+    args_.add_argument('--initializer', choices=['xavier'], help='initialize model parameters')
+    args_.add_argument('--opt', choices=['adam', 'sgd'], help='optimization algorithm')
+    args_.add_argument('--momentum', type=float, default=0.9, help='momentum of optimizer')
+    args_.add_argument('--betas', nargs=2, type=float, default=[0.9, 0.9], help='betas of optimizer')
+    args_.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
+    args_.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
+    args_.add_argument('--schedule', type=int, help='schedule for learning rate decay')
+    args_.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
+    args_.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
+    args_.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam')
+    args_.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
+    args_.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
+    args_.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
+    args_.add_argument('--unk_replace', type=float, default=0.,
+                       help='The rate to replace a singleton word with UNK')
+    args_.add_argument('--punct_set', nargs='+', type=str, help='List of punctuations')
+    args_.add_argument('--word_embedding', choices=['random', 'glove', 'fasttext', 'word2vec'],
+                       help='Embedding for words')
+    args_.add_argument('--word_path', help='path for word embedding dict - in case word_embedding is not random')
+    args_.add_argument('--freeze_word_embeddings', action='store_true', help='freeze the word embedding (disable fine-tuning).')
+    args_.add_argument('--char_embedding', choices=['random', 'hellwig'], help='Embedding for characters',
+                       required=True)
+    args_.add_argument('--pos_embedding', choices=['random', 'one_hot'], help='Embedding for pos',
+                       required=True)
+    args_.add_argument('--char_path', help='path for character embedding dict')
+    args_.add_argument('--pos_path', help='path for pos embedding dict')
+    args_.add_argument('--use_unlabeled_data', action='store_true', help='flag to use unlabeled data.')
+    args_.add_argument('--use_labeled_data', action='store_true', help='flag to use labeled data.')
+    args_.add_argument('--model_path', help='path for saving model file.', required=True)
+    args_.add_argument('--parser_path', help='path for loading parser files.', default=None)
+    args_.add_argument('--load_path', help='path for loading saved source model file.', default=None)
+    args_.add_argument('--strict', action='store_true', help='if True loaded model state should contain '
+                       'exactly the same keys as current model')
+    args_.add_argument('--eval_mode', action='store_true', help='evaluating model without training it')
+    args = args_.parse_args()
+    args_dict = {}
+    args_dict['dataset'] = args.dataset
+    args_dict['domain'] = args.domain
+    args_dict['task'] = args.task
+    args_dict['rnn_mode'] = args.rnn_mode
+    args_dict['load_path'] = args.load_path
+    args_dict['strict'] = args.strict
+    args_dict['model_path'] = args.model_path
+    if not path.exists(args_dict['model_path']):
+        makedirs(args_dict['model_path'])
+    args_dict['parser_path'] = args.parser_path
+    args_dict['model_name'] = 'domain_' + args_dict['domain']
+    args_dict['full_model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
+    args_dict['use_unlabeled_data'] = args.use_unlabeled_data
+    args_dict['use_labeled_data'] = args.use_labeled_data
+    print(args_dict['parser_path'])
+    if args_dict['task'] == 'number_of_children':
+        args_dict['data_paths'] = write_extra_labels.add_number_of_children(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'distance_from_the_root':
+        args_dict['data_paths'] = write_extra_labels.add_distance_from_the_root(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'Multitask_label_predict':
+        args_dict['data_paths'] = write_extra_labels.Multitask_label_predict(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'Multitask_coarse_predict':
+        args_dict['data_paths'] = write_extra_labels.Multitask_coarse_predict(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'Multitask_POS_predict':
+        args_dict['data_paths'] = write_extra_labels.Multitask_POS_predict(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'relative_pos_based':
+        args_dict['data_paths'] = write_extra_labels.add_relative_pos_based(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'add_label':
+        args_dict['data_paths'] = write_extra_labels.add_label(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'add_relative_TAG':
+        args_dict['data_paths'] = write_extra_labels.add_relative_TAG(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'add_head_coarse_pos':
+        args_dict['data_paths'] = write_extra_labels.add_head_coarse_pos(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'predict_ma_tag_of_modifier':
+        args_dict['data_paths'] = write_extra_labels.predict_ma_tag_of_modifier(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'Multitask_case_predict':
+        args_dict['data_paths'] = write_extra_labels.Multitask_case_predict(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'predict_coarse_of_modifier':
+        args_dict['data_paths'] = write_extra_labels.predict_coarse_of_modifier(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'predict_case_of_modifier':
+        args_dict['data_paths'] = write_extra_labels.predict_case_of_modifier(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'add_head_ma':
+        args_dict['data_paths'] = write_extra_labels.add_head_ma(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'MRL_case':
+        args_dict['data_paths'] = write_extra_labels.MRL_case(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'MRL_POS':
+        args_dict['data_paths'] = write_extra_labels.MRL_POS(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'MRL_no':
+        args_dict['data_paths'] = write_extra_labels.MRL_no(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'MRL_label':
+        args_dict['data_paths'] = write_extra_labels.MRL_label(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'MRL_Person':
+        args_dict['data_paths'] = write_extra_labels.MRL_Person(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    elif args_dict['task'] == 'MRL_Gender':
+        args_dict['data_paths'] = write_extra_labels.MRL_Gender(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
+    else:  # args_dict['task'] == 'language_model'
+        args_dict['data_paths'] = write_extra_labels.add_language_model(
+            args_dict['model_path'], args_dict['parser_path'], args_dict['domain'], args_dict['domain'],
+            use_unlabeled_data=args_dict['use_unlabeled_data'],
+            use_labeled_data=args_dict['use_labeled_data'])
    args_dict['splits'] = args_dict['data_paths'].keys()
    alphabet_data_paths = deepcopy(args_dict['data_paths'])
    if args_dict['dataset'] == 'ontonotes':
        data_path = 'data/onto_pos_ner_dp'
    else:
        data_path = 'data/ud_pos_ner_dp'
    # Add more resources to ensure an equal alphabet size across all domains
    for split in args_dict['splits']:
        if args_dict['dataset'] == 'ontonotes':
            alphabet_data_paths['additional_' + split] = data_path + '_' + split + '_' + 'all'
        else:
            if '_' in args_dict['domain']:
                alphabet_data_paths[split] = data_path + '_' + split + '_' + args_dict['domain'].split('_')[0]
            else:
                alphabet_data_paths[split] = args_dict['data_paths'][split]
    args_dict['alphabet_data_paths'] = alphabet_data_paths
    args_dict['num_epochs'] = args.num_epochs
    args_dict['batch_size'] = args.batch_size
    args_dict['hidden_size'] = args.hidden_size
    args_dict['tag_space'] = args.tag_space
    args_dict['num_layers'] = args.num_layers
    args_dict['num_filters'] = args.num_filters
    args_dict['kernel_size'] = args.kernel_size
    args_dict['learning_rate'] = args.learning_rate
    args_dict['initializer'] = nn.init.xavier_uniform_ if args.initializer == 'xavier' else None
    args_dict['opt'] = args.opt
    args_dict['momentum'] = args.momentum
    args_dict['betas'] = tuple(args.betas)
    args_dict['epsilon'] = args.epsilon
    args_dict['decay_rate'] = args.decay_rate
    args_dict['clip'] = args.clip
    args_dict['gamma'] = args.gamma
    args_dict['schedule'] = args.schedule
    args_dict['p_rnn'] = tuple(args.p_rnn)
    args_dict['p_in'] = args.p_in
    args_dict['p_out'] = args.p_out
    args_dict['unk_replace'] = args.unk_replace
    args_dict['punct_set'] = None
    if args.punct_set is not None:
        args_dict['punct_set'] = set(args.punct_set)
        logger.info("punctuations(%d): %s" % (len(args_dict['punct_set']), ' '.join(args_dict['punct_set'])))
    args_dict['freeze_word_embeddings'] = args.freeze_word_embeddings
    args_dict['word_embedding'] = args.word_embedding
    args_dict['word_path'] = args.word_path
    args_dict['use_char'] = args.use_char
    args_dict['char_embedding'] = args.char_embedding
    args_dict['pos_embedding'] = args.pos_embedding
    args_dict['char_path'] = args.char_path
    args_dict['pos_path'] = args.pos_path
    args_dict['use_pos'] = args.use_pos
    args_dict['pos_dim'] = args.pos_dim
    args_dict['word_dict'] = None
    args_dict['word_dim'] = args.word_dim
    if args_dict['word_embedding'] != 'random' and args_dict['word_path']:
        args_dict['word_dict'], args_dict['word_dim'] = load_word_embeddings.load_embedding_dict(
            args_dict['word_embedding'], args_dict['word_path'])
    args_dict['char_dict'] = None
    args_dict['char_dim'] = args.char_dim
    if args_dict['char_embedding'] != 'random':
        args_dict['char_dict'], args_dict['char_dim'] = load_word_embeddings.load_embedding_dict(
            args_dict['char_embedding'], args_dict['char_path'])
    args_dict['pos_dict'] = None
    if args_dict['pos_embedding'] != 'random':
        args_dict['pos_dict'], args_dict['pos_dim'] = load_word_embeddings.load_embedding_dict(
            args_dict['pos_embedding'], args_dict['pos_path'])
    args_dict['alphabet_path'] = path.join(args_dict['model_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
    args_dict['alphabet_parser_path'] = path.join(args_dict['parser_path'], 'alphabets' + '_src_domain_' + args_dict['domain'] + '/')
    args_dict['model_name'] = path.join(args_dict['model_path'], args_dict['model_name'])
    args_dict['eval_mode'] = args.eval_mode
    args_dict['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args_dict['word_status'] = 'frozen' if args.freeze_word_embeddings else 'fine tune'
    args_dict['char_status'] = 'enabled' if args.use_char else 'disabled'
    args_dict['pos_status'] = 'enabled' if args.use_pos else 'disabled'
    logger.info("Saving arguments to file")
    save_args(args, args_dict['full_model_name'])
    logger.info("Creating Alphabets")
    alphabet_dict = creating_alphabets(args_dict['alphabet_path'], args_dict['alphabet_parser_path'], args_dict['alphabet_data_paths'])
    args_dict = {**args_dict, **alphabet_dict}
    ARGS = namedtuple('ARGS', args_dict.keys())
    my_args = ARGS(**args_dict)
    return my_args

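Every branch of the task dispatch inside read_arguments calls a write_extra_labels helper with the identical argument list; a hedged sketch of an equivalent table-driven dispatch (assuming all of these helpers do share that signature, which the branches above suggest):

# Sketch only, not part of the uploaded file.
task_writers = {
    'Multitask_POS_predict': write_extra_labels.Multitask_POS_predict,
    'relative_pos_based': write_extra_labels.add_relative_pos_based,
    'add_label': write_extra_labels.add_label,
    'add_relative_TAG': write_extra_labels.add_relative_TAG,
    'add_head_coarse_pos': write_extra_labels.add_head_coarse_pos,
    'predict_ma_tag_of_modifier': write_extra_labels.predict_ma_tag_of_modifier,
    'Multitask_case_predict': write_extra_labels.Multitask_case_predict,
    'predict_coarse_of_modifier': write_extra_labels.predict_coarse_of_modifier,
    'predict_case_of_modifier': write_extra_labels.predict_case_of_modifier,
    'add_head_ma': write_extra_labels.add_head_ma,
    'MRL_case': write_extra_labels.MRL_case,
    'MRL_POS': write_extra_labels.MRL_POS,
    'MRL_no': write_extra_labels.MRL_no,
    'MRL_label': write_extra_labels.MRL_label,
    'MRL_Person': write_extra_labels.MRL_Person,
    'MRL_Gender': write_extra_labels.MRL_Gender,
}
# 'language_model' (and any unrecognized task) falls through to
# add_language_model, mirroring the else branch above.
writer = task_writers.get(args_dict['task'], write_extra_labels.add_language_model)
args_dict['data_paths'] = writer(args_dict['model_path'], args_dict['parser_path'],
                                 args_dict['domain'], args_dict['domain'],
                                 use_unlabeled_data=args_dict['use_unlabeled_data'],
                                 use_labeled_data=args_dict['use_labeled_data'])

The closing namedtuple conversion in read_arguments simply freezes the merged dictionary into an object with read-only attribute access.
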
def creating_alphabets(alphabet_path, alphabet_parser_path, alphabet_data_paths):
    data_paths_list = alphabet_data_paths.values()
    alphabet_dict = {}
    alphabet_dict['alphabets'] = prepare_data.create_alphabets_for_sequence_tagger(alphabet_path, alphabet_parser_path, data_paths_list)
    for k, v in alphabet_dict['alphabets'].items():
        num_key = 'num_' + k.split('_alphabet')[0]
        alphabet_dict[num_key] = v.size()
        logger.info("%s : %d" % (num_key, alphabet_dict[num_key]))
    return alphabet_dict

def construct_embedding_table(alphabet, tokens_dict, dim, token_type='word'):
    if tokens_dict is None:
        return None
    scale = np.sqrt(3.0 / dim)
    table = np.empty([alphabet.size(), dim], dtype=np.float32)
    table[prepare_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
    oov_tokens = 0
    for token, index in alphabet.items():
        if token in tokens_dict:
            embedding = tokens_dict[token]
            if token == 'ata':
                # note: the embedding of the literal token 'ata' is deliberately
                # re-randomized (reason undocumented in the source)
                embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
        else:
            embedding = np.random.uniform(-scale, scale, [1, dim]).astype(np.float32)
            oov_tokens += 1
        table[index, :] = embedding
    print('token type : %s, number of oov: %d' % (token_type, oov_tokens))
    table = torch.from_numpy(table)
    return table

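In construct_embedding_table, the bound scale = sqrt(3.0 / dim) makes each uniformly drawn component have variance scale**2 / 3 = 1 / dim, so a random row has expected squared norm 1 regardless of dim, keeping randomly initialized rows on a scale comparable to typical pretrained vectors. A standalone check (hypothetical dimensions):

import numpy as np

dim = 300
scale = np.sqrt(3.0 / dim)
rows = np.random.uniform(-scale, scale, size=(10000, dim))
print(rows.var())                      # ~ 1/dim = 0.00333
print((rows ** 2).sum(axis=1).mean())  # ~ 1.0: unit expected squared norm
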
def get_free_gpu():
    # pick the GPU with the most free memory, as reported by nvidia-smi
    system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp.txt')
    memory_available = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()]
    remove("tmp.txt")
    free_device = 'cuda:' + str(np.argmax(memory_available))
    return free_device

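A shell-free sketch of the same idea, assuming a PyTorch version (>= 1.10) that provides torch.cuda.mem_get_info; the repo's function shells out to nvidia-smi instead:

import torch

def get_free_gpu_torch():
    # hypothetical alternative: query free memory per device directly
    free = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
    return 'cuda:%d' % max(range(len(free)), key=free.__getitem__)
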
def save_args(args, full_model_name):
    arg_path = full_model_name + '.arg.json'
    argparse_dict = vars(args)
    with open(arg_path, 'w') as f:
        json.dump(argparse_dict, f)

def generate_optimizer(args, lr, params):
    params = filter(lambda param: param.requires_grad, params)
    if args.opt == 'adam':
        return Adam(params, lr=lr, betas=args.betas, weight_decay=args.gamma, eps=args.epsilon)
    elif args.opt == 'sgd':
        return SGD(params, lr=lr, momentum=args.momentum, weight_decay=args.gamma, nesterov=True)
    else:
        raise ValueError('Unknown optimization algorithm: %s' % args.opt)


def save_checkpoint(model, optimizer, opt, dev_eval_dict, test_eval_dict, full_model_name):
    path_name = full_model_name + '.pt'
    print('Saving model to: %s' % path_name)
    state = {'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'opt': opt, 'dev_eval_dict': dev_eval_dict, 'test_eval_dict': test_eval_dict}
    torch.save(state, path_name)


def load_checkpoint(args, model, optimizer, dev_eval_dict, test_eval_dict, start_epoch, load_path, strict=True):
    print('Loading saved model from: %s' % load_path)
    checkpoint = torch.load(load_path, map_location=args.device)
    if checkpoint['opt'] != args.opt:
        raise ValueError('loaded optimizer type is: %s instead of: %s' % (checkpoint['opt'], args.opt))
    model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    if strict:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        dev_eval_dict = checkpoint['dev_eval_dict']
        test_eval_dict = checkpoint['test_eval_dict']
        start_epoch = dev_eval_dict['in_domain']['epoch']
    return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch


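save_checkpoint and load_checkpoint form a resumable-training pair: with strict=True the optimizer state and evaluation history are restored and training resumes from the best epoch recorded on dev, while strict=False loads only compatible weights (for example, a backbone under a different head) and starts fresh. A toy round trip (hypothetical model and file name, not part of the pipeline):

from types import SimpleNamespace
import torch
from torch.optim import Adam

toy_model = torch.nn.Linear(4, 2)
toy_opt = Adam(toy_model.parameters(), lr=1e-3)
torch.save({'model_state_dict': toy_model.state_dict(),
            'optimizer_state_dict': toy_opt.state_dict(),
            'opt': 'adam',
            'dev_eval_dict': {'in_domain': {'epoch': 3}},
            'test_eval_dict': {'in_domain': {}}}, 'toy.pt')
toy_args = SimpleNamespace(opt='adam', device=torch.device('cpu'))
_, _, _, _, start_epoch = load_checkpoint(toy_args, toy_model, toy_opt, {}, {}, 0, 'toy.pt')
print(start_epoch)  # 3
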
def build_model_and_optimizer(args):
    word_table = construct_embedding_table(args.alphabets['word_alphabet'], args.word_dict, args.word_dim, token_type='word')
    char_table = construct_embedding_table(args.alphabets['char_alphabet'], args.char_dict, args.char_dim, token_type='char')
    pos_table = construct_embedding_table(args.alphabets['pos_alphabet'], args.pos_dict, args.pos_dim, token_type='pos')
    model = Sequence_Tagger(args.word_dim, args.num_word, args.char_dim, args.num_char,
                            args.use_pos, args.use_char, args.pos_dim, args.num_pos,
                            args.num_filters, args.kernel_size, args.rnn_mode,
                            args.hidden_size, args.num_layers, args.tag_space, args.num_auto_label,
                            embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table,
                            p_in=args.p_in, p_out=args.p_out, p_rnn=args.p_rnn,
                            initializer=args.initializer)
    optimizer = generate_optimizer(args, args.learning_rate, model.parameters())
    start_epoch = 0
    dev_eval_dict = {'in_domain': initialize_eval_dict()}
    test_eval_dict = {'in_domain': initialize_eval_dict()}
    if args.load_path:
        model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = \
            load_checkpoint(args, model, optimizer,
                            dev_eval_dict, test_eval_dict,
                            start_epoch, args.load_path, strict=args.strict)
    if args.freeze_word_embeddings:
        model.rnn_encoder.word_embedd.weight.requires_grad = False
        # model.rnn_encoder.char_embedd.weight.requires_grad = False
        # model.rnn_encoder.pos_embedd.weight.requires_grad = False
    device = args.device
    model.to(device)
    return model, optimizer, dev_eval_dict, test_eval_dict, start_epoch


def initialize_eval_dict():
    eval_dict = {}
    eval_dict['auto_label_accuracy'] = 0.0
    eval_dict['auto_label_precision'] = 0.0
    eval_dict['auto_label_recall'] = 0.0
    eval_dict['auto_label_f1'] = 0.0
    return eval_dict

def in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
                         best_model, best_optimizer, patient):
    # In-domain evaluation
    curr_dev_eval_dict = evaluation(args, datasets['dev'], 'dev', model, args.domain, epoch, 'current_results')
    is_best_in_domain = dev_eval_dict['in_domain']['auto_label_f1'] <= curr_dev_eval_dict['auto_label_f1']

    if is_best_in_domain:
        for key, value in curr_dev_eval_dict.items():
            dev_eval_dict['in_domain'][key] = value
        curr_test_eval_dict = evaluation(args, datasets['test'], 'test', model, args.domain, epoch, 'current_results')
        for key, value in curr_test_eval_dict.items():
            test_eval_dict['in_domain'][key] = value
        best_model = deepcopy(model)
        best_optimizer = deepcopy(optimizer)
        patient = 0
    else:
        patient += 1
    if epoch == args.num_epochs:
        # save in-domain checkpoint
        for split in ['dev', 'test']:
            eval_dict = dev_eval_dict['in_domain'] if split == 'dev' else test_eval_dict['in_domain']
            write_results(args, datasets[split], args.domain, split, best_model, args.domain, eval_dict)
        save_checkpoint(best_model, best_optimizer, args.opt, dev_eval_dict, test_eval_dict, args.full_model_name)

    print('\n')
    return dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict


def evaluation(args, data, split, model, domain, epoch, str_res='results'):
    # evaluate performance on data
    model.eval()
    auto_label_idx2inst = Index2Instance(args.alphabets['auto_label_alphabet'])
    eval_dict = initialize_eval_dict()
    eval_dict['epoch'] = epoch
    pred_labels = []
    gold_labels = []
    for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
        word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
        output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
        auto_label_preds = model.decode(output, mask=masks, length=lengths, leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
        lengths = lengths.cpu().numpy()
        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        ner = ner.data.cpu().numpy()
        heads = heads.data.cpu().numpy()
        arc_tags = arc_tags.data.cpu().numpy()
        auto_label = auto_label.data.cpu().numpy()
        auto_label_preds = auto_label_preds.data.cpu().numpy()
        gold_labels += auto_label_idx2inst.index2instance(auto_label, lengths, symbolic_root=True)
        pred_labels += auto_label_idx2inst.index2instance(auto_label_preds, lengths, symbolic_root=True)

    eval_dict['auto_label_accuracy'] = accuracy_score(gold_labels, pred_labels) * 100
    eval_dict['auto_label_precision'] = precision_score(gold_labels, pred_labels) * 100
    eval_dict['auto_label_recall'] = recall_score(gold_labels, pred_labels) * 100
    eval_dict['auto_label_f1'] = f1_score(gold_labels, pred_labels) * 100
    eval_dict['classification_report'] = classification_report(gold_labels, pred_labels)
    print_results(eval_dict, split, domain, str_res)
    return eval_dict


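gold_labels and pred_labels are lists of per-sentence tag sequences, the input shape expected by seqeval-style metric functions rather than flat scikit-learn ones; the import block lies outside this excerpt, so treat the library choice as an assumption. A standalone illustration under that assumption:

from seqeval.metrics import f1_score  # assumption: metrics come from seqeval

gold = [['O', 'B-PER', 'I-PER'], ['B-LOC', 'O']]
pred = [['O', 'B-PER', 'I-PER'], ['B-ORG', 'O']]
print(f1_score(gold, pred))  # 0.5: span-level precision 1/2 and recall 1/2
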
def print_results(eval_dict, split, domain, str_res='results'):
    print('----------------------------------------------------------------------------------------------------------------------------')
    print('Testing model on domain %s' % domain)
    print('--------------- sequence_tagger - %s ---------------' % split)
    print(
        str_res + ' on ' + split + ' accuracy: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)'
        % (eval_dict['auto_label_accuracy'], eval_dict['auto_label_precision'], eval_dict['auto_label_recall'], eval_dict['auto_label_f1'],
           eval_dict['epoch']))
    print(eval_dict['classification_report'])


def write_results(args, data, data_domain, split, model, model_domain, eval_dict):
    str_file = args.full_model_name + '_' + split + '_model_domain_' + model_domain + '_data_domain_' + data_domain
    res_filename = str_file + '_res.txt'
    pred_filename = str_file + '_pred.txt'
    gold_filename = str_file + '_gold.txt'

    # save results dictionary into a file
    with open(res_filename, 'w') as f:
        json.dump(eval_dict, f)

    # save predictions and gold labels into files
    pred_writer = Writer(args.alphabets)
    gold_writer = Writer(args.alphabets)
    pred_writer.start(pred_filename)
    gold_writer.start(gold_filename)
    for batch in prepare_data.iterate_batch(data, args.batch_size, args.device):
        word, char, pos, ner, heads, arc_tags, auto_label, masks, lengths = batch
        output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
        auto_label_preds = model.decode(output, mask=masks, length=lengths,
                                        leading_symbolic=prepare_data.NUM_SYMBOLIC_TAGS)
        lengths = lengths.cpu().numpy()
        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        ner = ner.data.cpu().numpy()
        heads = heads.data.cpu().numpy()
        arc_tags = arc_tags.data.cpu().numpy()
        auto_label_preds = auto_label_preds.data.cpu().numpy()
        # writing predictions
        pred_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label_preds, symbolic_root=True)
        # writing gold labels
        gold_writer.write(word, pos, ner, heads, arc_tags, lengths, auto_label=auto_label, symbolic_root=True)

    pred_writer.close()
    gold_writer.close()

def main():
    logger.info("Reading and creating arguments")
    args = read_arguments()
    logger.info("Reading Data")
    datasets = {}
    for split in args.splits:
        dataset = prepare_data.read_data_to_variable(args.data_paths[split], args.alphabets, args.device,
                                                     symbolic_root=True)
        datasets[split] = dataset

    logger.info("Creating Networks")
    num_data = sum(datasets['train'][1])
    model, optimizer, dev_eval_dict, test_eval_dict, start_epoch = build_model_and_optimizer(args)
    best_model = deepcopy(model)
    best_optimizer = deepcopy(optimizer)
    logger.info('Training INFO of in domain %s' % args.domain)
    logger.info('Training on Dependency Parsing')
    print(model)
    logger.info("train: gamma: %f, batch: %d, clip: %.2f, unk replace: %.2f" % (args.gamma, args.batch_size, args.clip, args.unk_replace))
    logger.info('number of training samples for %s is: %d' % (args.domain, num_data))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (args.p_in, args.p_out, args.p_rnn))
    logger.info("num_epochs: %d" % (args.num_epochs))
    print('\n')

    if not args.eval_mode:
        logger.info("Training")
        num_batches = prepare_data.calc_num_batches(datasets['train'], args.batch_size)
        lr = args.learning_rate
        patient = 0
        terminal_patient = 0
        decay = 0
        for epoch in range(start_epoch + 1, args.num_epochs + 1):
            print('Epoch %d (Training: rnn mode: %s, optimizer: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, decay=%d)): ' % (
                epoch, args.rnn_mode, args.opt, lr, args.epsilon, args.decay_rate, args.schedule, decay))
            model.train()
            total_loss = 0.0
            total_train_inst = 0.0

            iter = prepare_data.iterate_batch_rand_bucket_choosing(
                datasets['train'], args.batch_size, args.device, unk_replace=args.unk_replace)
            start_time = time.time()
            batch_num = 0
            for batch_num, batch in enumerate(iter):
                batch_num = batch_num + 1
                optimizer.zero_grad()
                # compute loss of main task
                word, char, pos, ner_tags, heads, arc_tags, auto_label, masks, lengths = batch
                output, masks, lengths = model.forward(word, char, pos, mask=masks, length=lengths)
                loss = model.loss(output, auto_label, mask=masks, length=lengths)

                # update losses
                num_insts = masks.data.sum() - word.size(0)
                total_loss += loss.item() * num_insts
                total_train_inst += num_insts
                # optimize parameters
                loss.backward()
                clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()

                time_ave = (time.time() - start_time) / batch_num
                time_left = (num_batches - batch_num) * time_ave

                # update log
                if batch_num % 50 == 0:
                    log_info = 'train: %d/%d, domain: %s, total loss: %.2f, time left: %.2fs' % \
                               (batch_num, num_batches, args.domain, total_loss / total_train_inst, time_left)
                    sys.stdout.write(log_info)
                    sys.stdout.write('\n')
                    sys.stdout.flush()
            print('\n')
            print('train: %d/%d, domain: %s, total_loss: %.2f, time: %.2fs' %
                  (batch_num, num_batches, args.domain, total_loss / total_train_inst, time.time() - start_time))

            dev_eval_dict, test_eval_dict, best_model, best_optimizer, patient, curr_dev_eval_dict = \
                in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch,
                                     best_model, best_optimizer, patient)
            store = {'dev_eval_dict': curr_dev_eval_dict}
            #############################################
            str_file = args.full_model_name + '_' + 'all_epochs'
            with open(str_file, 'a') as f:
                f.write(str(store) + '\n')
            if patient == 0:
                terminal_patient = 0
            else:
                terminal_patient += 1
            if terminal_patient >= 4 * args.schedule:
                # Save best model and terminate learning
                cur_epoch = epoch
                epoch = args.num_epochs
                in_domain_evaluation(args, datasets, model, optimizer, dev_eval_dict, test_eval_dict, epoch, best_model,
                                     best_optimizer, patient)
                log_info = 'Terminating training in epoch %d' % (cur_epoch)
                sys.stdout.write(log_info)
                sys.stdout.write('\n')
                sys.stdout.flush()
                return
            if patient >= args.schedule:
                lr = args.learning_rate / (1.0 + epoch * args.decay_rate)
                optimizer = generate_optimizer(args, lr, model.parameters())
                print('updated learning rate to %.6f' % lr)
                patient = 0
            print_results(test_eval_dict['in_domain'], 'test', args.domain, 'best_results')
            print('\n')

    else:
        logger.info("Evaluating")
        epoch = start_epoch
        for split in ['train', 'dev', 'test']:
            evaluation(args, datasets[split], split, model, args.domain, epoch, 'best_results')


if __name__ == '__main__':
    main()
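Two details of the training loop above are worth spelling out. Patience is tracked at two granularities: after `schedule` epochs without a new best dev F1 the learning rate is annealed with inverse-time decay, lr = lr0 / (1 + epoch * decay_rate), and after 4 * schedule consecutive non-improving epochs training stops early. The decay curve for hypothetical values:

lr0, decay_rate = 0.002, 0.05
for epoch in (5, 10, 20, 40):
    print(epoch, round(lr0 / (1.0 + epoch * decay_rate), 6))
# 5 0.0016, 10 0.001333, 20 0.001, 40 0.000667
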
examples/eval/conll03eval.v2
ADDED
@@ -0,0 +1,336 @@
#!/usr/bin/perl -w
# conlleval: evaluate result of processing CoNLL-2000 shared task
# usage:     conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
# README:    http://cnts.uia.ac.be/conll2000/chunking/output.html
# options:   l: generate LaTeX output for tables like in
#               http://cnts.uia.ac.be/conll2003/ner/example.tex
#            r: accept raw result tags (without B- and I- prefix;
#               assumes one word per chunk)
#            d: alternative delimiter tag (default is single space)
#            o: alternative outside tag (default is O)
# note:      the file should contain lines with items separated
#            by $delimiter characters (default space). The final
#            two items should contain the correct tag and the
#            guessed tag in that order. Sentences should be
#            separated from each other by empty lines or lines
#            with $boundary fields (default -X-).
# url:       http://lcg-www.uia.ac.be/conll2000/chunking/
# started:   1998-09-25
# version:   2004-01-26
# author:    Erik Tjong Kim Sang <erikt@uia.ua.ac.be>

use strict;

my $false = 0;
my $true = 42;

my $boundary = "-X-";     # sentence boundary
my $correct;              # current corpus chunk tag (I,O,B)
my $correctChunk = 0;     # number of correctly identified chunks
my $correctTags = 0;      # number of correct chunk tags
my $correctType;          # type of current corpus chunk tag (NP,VP,etc.)
my $delimiter = " ";      # field delimiter
my $FB1 = 0.0;            # FB1 score (Van Rijsbergen 1979)
my $firstItem;            # first feature (for sentence boundary checks)
my $foundCorrect = 0;     # number of chunks in corpus
my $foundGuessed = 0;     # number of identified chunks
my $guessed;              # current guessed chunk tag
my $guessedType;          # type of current guessed chunk tag
my $i;                    # miscellaneous counter
my $inCorrect = $false;   # currently processed chunk is correct until now
my $lastCorrect = "O";    # previous chunk tag in corpus
my $latex = 0;            # generate LaTeX formatted output
my $lastCorrectType = ""; # type of previously identified chunk tag
my $lastGuessed = "O";    # previously identified chunk tag
my $lastGuessedType = ""; # type of previous chunk tag in corpus
my $lastType;             # temporary storage for detecting duplicates
my $line;                 # line
my $nbrOfFeatures = -1;   # number of features per line
my $precision = 0.0;      # precision score
my $oTag = "O";           # outside tag, default O
my $raw = 0;              # raw input: add B to every token
my $recall = 0.0;         # recall score
my $tokenCounter = 0;     # token counter (ignores sentence breaks)

my %correctChunk = ();    # number of correctly identified chunks per type
my %foundCorrect = ();    # number of chunks in corpus per type
my %foundGuessed = ();    # number of identified chunks per type

my @features;             # features on line
my @sortedTypes;          # sorted list of chunk type names

# sanity check
while (@ARGV and $ARGV[0] =~ /^-/) {
    if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
    elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
    elsif ($ARGV[0] eq "-d") {
        shift(@ARGV);
        if (not defined $ARGV[0]) {
            die "conlleval: -d requires delimiter character";
        }
        $delimiter = shift(@ARGV);
    } elsif ($ARGV[0] eq "-o") {
        shift(@ARGV);
        if (not defined $ARGV[0]) {
            die "conlleval: -o requires delimiter character";
        }
        $oTag = shift(@ARGV);
    } else { die "conlleval: unknown argument $ARGV[0]\n"; }
}
if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
# process input
while (<STDIN>) {
    chomp($line = $_);
    @features = split(/$delimiter/,$line);
    if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
    elsif ($nbrOfFeatures != $#features and @features != 0) {
        printf STDERR "unexpected number of features: %d (%d)\n",
            $#features+1,$nbrOfFeatures+1;
        exit(1);
    }
    if (@features == 0 or
        $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
    if (@features < 2) {
        die "conlleval: unexpected number of features in line $line\n";
    }
    if ($raw) {
        if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
        if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
        if ($features[$#features] ne "O") {
            $features[$#features] = "B-$features[$#features]";
        }
        if ($features[$#features-1] ne "O") {
            $features[$#features-1] = "B-$features[$#features-1]";
        }
    }
    # 20040126 ET code which allows hyphens in the types
    if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
        $guessed = $1;
        $guessedType = $2;
    } else {
        $guessed = $features[$#features];
        $guessedType = "";
    }
    pop(@features);
    if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
        $correct = $1;
        $correctType = $2;
    } else {
        $correct = $features[$#features];
        $correctType = "";
    }
    pop(@features);
    # ($guessed,$guessedType) = split(/-/,pop(@features));
    # ($correct,$correctType) = split(/-/,pop(@features));
    $guessedType = $guessedType ? $guessedType : "";
    $correctType = $correctType ? $correctType : "";
    $firstItem = shift(@features);

    # 1999-06-26 sentence breaks should always be counted as out of chunk
    if ( $firstItem eq $boundary ) { $guessed = "O"; }

    if ($inCorrect) {
        if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
             &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
             $lastGuessedType eq $lastCorrectType) {
            $inCorrect=$false;
            $correctChunk++;
            $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
                $correctChunk{$lastCorrectType}+1 : 1;
        } elsif (
             &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
             &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
             $guessedType ne $correctType ) {
            $inCorrect=$false;
        }
    }

    if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
         &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
         $guessedType eq $correctType) { $inCorrect = $true; }

    if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
        $foundCorrect++;
        $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
            $foundCorrect{$correctType}+1 : 1;
    }
    if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
        $foundGuessed++;
        $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
            $foundGuessed{$guessedType}+1 : 1;
    }
    if ( $firstItem ne $boundary ) {
        if ( $correct eq $guessed and $guessedType eq $correctType ) {
            $correctTags++;
        }
        $tokenCounter++;
    }

    $lastGuessed = $guessed;
    $lastCorrect = $correct;
    $lastGuessedType = $guessedType;
    $lastCorrectType = $correctType;
}
if ($inCorrect) {
    $correctChunk++;
    $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
        $correctChunk{$lastCorrectType}+1 : 1;
}

if (not $latex) {
    # compute overall precision, recall and FB1 (default values are 0.0)
    $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
    $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
    $FB1 = 2*$precision*$recall/($precision+$recall)
        if ($precision+$recall > 0);

    # print overall performance
    printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
    printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
    if ($tokenCounter>0) {
        printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
        printf "precision: %6.2f%%; ",$precision;
        printf "recall: %6.2f%%; ",$recall;
        printf "FB1: %6.2f\n",$FB1;
    }
}

# sort chunk type names
undef($lastType);
@sortedTypes = ();
foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
    if (not($lastType) or $lastType ne $i) {
        push(@sortedTypes,($i));
    }
    $lastType = $i;
}
# print performance per chunk type
if (not $latex) {
    for $i (@sortedTypes) {
        $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
        if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
        else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
        if (not($foundCorrect{$i})) { $recall = 0.0; }
        else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
        if ($precision+$recall == 0.0) { $FB1 = 0.0; }
        else { $FB1 = 2*$precision*$recall/($precision+$recall); }
        printf "%17s: ",$i;
        printf "precision: %6.2f%%; ",$precision;
        printf "recall: %6.2f%%; ",$recall;
        printf "FB1: %6.2f  %d\n",$FB1,$foundGuessed{$i};
    }
} else {
    print "        & Precision &  Recall  & F\$_{\\beta=1} \\\\\\hline";
    for $i (@sortedTypes) {
        $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
        if (not($foundGuessed{$i})) { $precision = 0.0; }
        else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
        if (not($foundCorrect{$i})) { $recall = 0.0; }
        else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
        if ($precision+$recall == 0.0) { $FB1 = 0.0; }
        else { $FB1 = 2*$precision*$recall/($precision+$recall); }
        printf "\n%-7s &  %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
            $i,$precision,$recall,$FB1;
    }
    print "\\hline\n";
    $precision = 0.0;
    $recall = 0;
    $FB1 = 0.0;
    $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
    $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
    $FB1 = 2*$precision*$recall/($precision+$recall)
        if ($precision+$recall > 0);
    printf "Overall &  %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
        $precision,$recall,$FB1;
}

exit 0;

# endOfChunk: checks if a chunk ended between the previous and current word
# arguments:  previous and current chunk tags, previous and current types
# note:       this code is capable of handling other chunk representations
#             than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
#             Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006

sub endOfChunk {
    my $prevTag = shift(@_);
    my $tag = shift(@_);
    my $prevType = shift(@_);
    my $type = shift(@_);
    my $chunkEnd = $false;

    if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
    if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
    if ( $prevTag eq "B" and $tag eq "S" ) { $chunkEnd = $true; }

    if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
    if ( $prevTag eq "I" and $tag eq "S" ) { $chunkEnd = $true; }
    if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }

    if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
    if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
    if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
    if ( $prevTag eq "E" and $tag eq "S" ) { $chunkEnd = $true; }
    if ( $prevTag eq "E" and $tag eq "B" ) { $chunkEnd = $true; }

    if ( $prevTag eq "S" and $tag eq "E" ) { $chunkEnd = $true; }
    if ( $prevTag eq "S" and $tag eq "I" ) { $chunkEnd = $true; }
    if ( $prevTag eq "S" and $tag eq "O" ) { $chunkEnd = $true; }
    if ( $prevTag eq "S" and $tag eq "S" ) { $chunkEnd = $true; }
    if ( $prevTag eq "S" and $tag eq "B" ) { $chunkEnd = $true; }


    if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
        $chunkEnd = $true;
    }

    # corrected 1998-12-22: these chunks are assumed to have length 1
    if ( $prevTag eq "]" ) { $chunkEnd = $true; }
    if ( $prevTag eq "[" ) { $chunkEnd = $true; }

    return($chunkEnd);
}

# startOfChunk: checks if a chunk started between the previous and current word
# arguments:   previous and current chunk tags, previous and current types
# note:        this code is capable of handling other chunk representations
#              than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
#              Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006

sub startOfChunk {
    my $prevTag = shift(@_);
    my $tag = shift(@_);
    my $prevType = shift(@_);
    my $type = shift(@_);
    my $chunkStart = $false;

    if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
    if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
    if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
    if ( $prevTag eq "S" and $tag eq "B" ) { $chunkStart = $true; }
    if ( $prevTag eq "E" and $tag eq "B" ) { $chunkStart = $true; }

    if ( $prevTag eq "B" and $tag eq "S" ) { $chunkStart = $true; }
    if ( $prevTag eq "I" and $tag eq "S" ) { $chunkStart = $true; }
    if ( $prevTag eq "O" and $tag eq "S" ) { $chunkStart = $true; }
    if ( $prevTag eq "S" and $tag eq "S" ) { $chunkStart = $true; }
    if ( $prevTag eq "E" and $tag eq "S" ) { $chunkStart = $true; }

    if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
    if ( $prevTag eq "S" and $tag eq "I" ) { $chunkStart = $true; }
    if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }

    if ( $prevTag eq "S" and $tag eq "E" ) { $chunkStart = $true; }
    if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
    if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }

    if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
        $chunkStart = $true;
    }

    # corrected 1998-12-22: these chunks are assumed to have length 1
    if ( $tag eq "[" ) { $chunkStart = $true; }
    if ( $tag eq "]" ) { $chunkStart = $true; }

    return($chunkStart);
}
examples/eval/conll06eval.pl
ADDED
@@ -0,0 +1,1826 @@
1 |
+
#!/usr/bin/env perl
|
2 |
+
|
3 |
+
# Author: Yuval Krymolowski
|
4 |
+
# Addition of precision and recall
|
5 |
+
# and of frame confusion list: Sabine Buchholz
|
6 |
+
# Addition of DEPREL + ATTACHMENT:
|
7 |
+
# Prokopis Prokopidis (prokopis at ilsp dot gr)
|
8 |
+
# Acknowledgements:
|
9 |
+
# to Markus Kuhn for suggesting the use of
|
10 |
+
# the Unicode category property
|
11 |
+
|
12 |
+
if ($] < 5.008001)
|
13 |
+
{
|
14 |
+
printf STDERR <<EOM
|
15 |
+
|
16 |
+
This script requires PERL 5.8.1 for running.
|
17 |
+
The new version is needed for proper handling
|
18 |
+
of Unicode characters.
|
19 |
+
|
20 |
+
Please obtain a new version or contact the shared task team
|
21 |
+
if you are unable to upgrade PERL.
|
22 |
+
|
23 |
+
EOM
|
24 |
+
;
|
25 |
+
exit(1) ;
|
26 |
+
}
|
27 |
+
|
28 |
+
require Encode;
|
29 |
+
|
30 |
+
use strict ;
|
31 |
+
use warnings;
|
32 |
+
use Getopt::Std ;
|
33 |
+
|
34 |
+
my ($usage) = <<EOT
|
35 |
+
|
36 |
+
CoNLL-X evaluation script:
|
37 |
+
|
38 |
+
[perl] eval.pl [OPTIONS] -g <gold standard> -s <system output>
|
39 |
+
|
40 |
+
This script evaluates a system output with respect to a gold standard.
|
41 |
+
Both files should be in UTF-8 encoded CoNLL-X tabular format.
|
42 |
+
|
43 |
+
Punctuation tokens (those where all characters have the Unicode
|
44 |
+
category property "Punctuation") are ignored for scoring (unless the
|
45 |
+
-p flag is used).
|
46 |
+
|
47 |
+
The output breaks down the errors according to their type and context.
|
48 |
+
|
49 |
+
Optional parameters:
|
50 |
+
-o FILE : output: print output to FILE (default is standard output)
|
51 |
+
-q : quiet: only print overall performance, without the details
|
52 |
+
-b : evalb: produce output in a format similar to evalb
|
53 |
+
(http://nlp.cs.nyu.edu/evalb/); use together with -q
|
54 |
+
-p : punctuation: also score on punctuation (default is not to score on it)
|
55 |
+
-v : version: show the version number
|
56 |
+
-h : help: print this help text and exit
|
57 |
+
|
58 |
+
EOT
|
59 |
+
;
|
60 |
+
|
61 |
+
my ($line_num) ;
|
62 |
+
my ($sep) = '0x01' ;
|
63 |
+
|
64 |
+
my ($START) = '.S' ;
|
65 |
+
my ($END) = '.E' ;
|
66 |
+
|
67 |
+
my ($con_err_num) = 3 ;
|
68 |
+
my ($freq_err_num) = 10 ;
|
69 |
+
my ($spec_err_loc_con) = 8 ;
|
70 |
+
|
71 |
+
################################################################################
|
72 |
+
### subfunctions ###
|
73 |
+
################################################################################
|
74 |
+
|
75 |
+
# Whether a string consists entirely of characters with the Unicode
|
76 |
+
# category property "Punctuation" (see "man perlunicode")
|
77 |
+
sub is_uni_punct
|
78 |
+
{
|
79 |
+
my ($word) = @_ ;
|
80 |
+
|
81 |
+
return scalar(Encode::decode_utf8($word)=~ /^\p{Punctuation}+$/) ;
|
82 |
+
}
|
83 |
+
|
84 |
+
# The length of a unicode string, excluding non-spacing marks
|
85 |
+
# (for example vowel marks in Arabic)
|
86 |
+
|
87 |
+
sub uni_len
|
88 |
+
{
|
89 |
+
my ($word) = @_ ;
|
90 |
+
my ($ch, $l) ;
|
91 |
+
|
92 |
+
$l = 0 ;
|
93 |
+
foreach $ch (split(//, Encode::decode_utf8($word)))
|
94 |
+
{
|
95 |
+
if ($ch !~ /^\p{NonspacingMark}/)
|
96 |
+
{
|
97 |
+
$l++ ;
|
98 |
+
}
|
99 |
+
}
|
100 |
+
|
101 |
+
return $l ;
|
102 |
+
}
|
103 |
+
|
104 |
+
sub filter_context_counts
|
105 |
+
{ # filter_context_counts
|
106 |
+
|
107 |
+
my ($vec, $num, $max_len) = @_ ;
|
108 |
+
my ($con, $l, $thresh) ;
|
109 |
+
|
110 |
+
$thresh = (sort {$b <=> $a} values %{$vec})[$num-1] ;
|
111 |
+
|
112 |
+
foreach $con (keys %{$vec})
|
113 |
+
{
|
114 |
+
if (${$vec}{$con} < $thresh)
|
115 |
+
{
|
116 |
+
delete ${$vec}{$con} ;
|
117 |
+
next ;
|
118 |
+
}
|
119 |
+
|
120 |
+
$l = uni_len($con) ;
|
121 |
+
|
122 |
+
if ($l > ${$max_len})
|
123 |
+
{
|
124 |
+
${$max_len} = $l ;
|
125 |
+
}
|
126 |
+
}
|
127 |
+
|
128 |
+
} # filter_context_counts
|
129 |
+
|
130 |
+
sub print_context
|
131 |
+
{ # print_context
|
132 |
+
|
133 |
+
my ($counts, $counts_pos, $max_con_len, $max_con_pos_len) = @_ ;
|
134 |
+
my (@v_con, @v_con_pos, $con, $con_pos, $i, $n) ;
|
135 |
+
|
136 |
+
printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_pos_len, 'CPOS', 'any', 'head', 'dep', 'both' ;
|
137 |
+
printf OUT " ||" ;
|
138 |
+
printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s", $max_con_len, 'word', 'any', 'head', 'dep', 'both' ;
|
139 |
+
printf OUT "\n" ;
|
140 |
+
printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
|
141 |
+
printf OUT "--++" ;
|
142 |
+
printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
|
143 |
+
printf OUT "\n" ;
|
144 |
+
|
145 |
+
@v_con = sort {${$counts}{tot}{$b} <=> ${$counts}{tot}{$a}} keys %{${$counts}{tot}} ;
|
146 |
+
@v_con_pos = sort {${$counts_pos}{tot}{$b} <=> ${$counts_pos}{tot}{$a}} keys %{${$counts_pos}{tot}} ;
|
147 |
+
|
148 |
+
$n = scalar @v_con ;
|
149 |
+
if (scalar @v_con_pos > $n)
|
150 |
+
{
|
151 |
+
$n = scalar @v_con_pos ;
|
152 |
+
}
|
153 |
+
|
154 |
+
foreach $i (0 .. $n-1)
|
155 |
+
{
|
156 |
+
if (defined $v_con_pos[$i])
|
157 |
+
{
|
158 |
+
$con_pos = $v_con_pos[$i] ;
|
159 |
+
printf OUT " %-*s | %4d | %4d | %4d | %4d",
|
160 |
+
$max_con_pos_len, $con_pos, ${$counts_pos}{tot}{$con_pos},
|
161 |
+
${$counts_pos}{err_head}{$con_pos}, ${$counts_pos}{err_dep}{$con_pos},
|
162 |
+
${$counts_pos}{err_dep}{$con_pos}+${$counts_pos}{err_head}{$con_pos}-${$counts_pos}{tot}{$con_pos} ;
|
163 |
+
}
|
164 |
+
else
|
165 |
+
{
|
166 |
+
printf OUT " %-*s | %4s | %4s | %4s | %4s",
|
167 |
+
$max_con_pos_len, ' ', ' ', ' ', ' ', ' ' ;
|
168 |
+
}
|
169 |
+
|
170 |
+
printf OUT " ||" ;
|
171 |
+
|
172 |
+
if (defined $v_con[$i])
|
173 |
+
{
|
174 |
+
$con = $v_con[$i] ;
|
175 |
+
printf OUT " %-*s | %4d | %4d | %4d | %4d",
|
176 |
+
$max_con_len+length($con)-uni_len($con), $con, ${$counts}{tot}{$con},
|
177 |
+
${$counts}{err_head}{$con}, ${$counts}{err_dep}{$con},
|
178 |
+
${$counts}{err_dep}{$con}+${$counts}{err_head}{$con}-${$counts}{tot}{$con} ;
|
179 |
+
}
|
180 |
+
else
|
181 |
+
{
|
182 |
+
printf OUT " %-*s | %4s | %4s | %4s | %4s",
|
183 |
+
$max_con_len, ' ', ' ', ' ', ' ', ' ' ;
|
184 |
+
}
|
185 |
+
|
186 |
+
printf OUT "\n" ;
|
187 |
+
}
|
188 |
+
|
189 |
+
printf OUT " %s-+------+------+------+-----", '-' x $max_con_pos_len;
|
190 |
+
printf OUT "--++" ;
|
191 |
+
printf OUT "--%s-+------+------+------+-----", '-' x $max_con_len;
|
192 |
+
printf OUT "\n" ;
|
193 |
+
|
194 |
+
printf OUT "\n\n" ;
|
195 |
+
|
196 |
+
} # print_context
|
197 |
+
|
198 |
+
sub num_as_word
|
199 |
+
{
|
200 |
+
my ($num) = @_ ;
|
201 |
+
|
202 |
+
$num = abs($num) ;
|
203 |
+
|
204 |
+
if ($num == 1)
|
205 |
+
{
|
206 |
+
return ('one word') ;
|
207 |
+
}
|
208 |
+
elsif ($num == 2)
|
209 |
+
{
|
210 |
+
return ('two words') ;
|
211 |
+
}
|
212 |
+
elsif ($num == 3)
|
213 |
+
{
|
214 |
+
return ('three words') ;
|
215 |
+
}
|
216 |
+
elsif ($num == 4)
|
217 |
+
{
|
218 |
+
return ('four words') ;
|
219 |
+
}
|
220 |
+
else
|
221 |
+
{
|
222 |
+
return ($num.' words') ;
|
223 |
+
}
|
224 |
+
}
|
225 |
+
|
226 |
+
sub describe_err
|
227 |
+
{ # describe_err
|
228 |
+
|
229 |
+
my ($head_err, $head_aft_bef, $dep_err) = @_ ;
|
230 |
+
my ($dep_g, $dep_s, $desc) ;
|
231 |
+
my ($head_aft_bef_g, $head_aft_bef_s) = split(//, $head_aft_bef) ;
|
232 |
+
|
233 |
+
if ($head_err eq '-')
|
234 |
+
{
|
235 |
+
$desc = 'correct head' ;
|
236 |
+
|
237 |
+
if ($head_aft_bef_s eq '0')
|
238 |
+
{
|
239 |
+
$desc .= ' (0)' ;
|
240 |
+
}
|
241 |
+
elsif ($head_aft_bef_s eq 'e')
|
242 |
+
{
|
243 |
+
$desc .= ' (the focus word)' ;
|
244 |
+
}
|
245 |
+
elsif ($head_aft_bef_s eq 'a')
|
246 |
+
{
|
247 |
+
$desc .= ' (after the focus word)' ;
|
248 |
+
}
|
249 |
+
elsif ($head_aft_bef_s eq 'b')
|
250 |
+
{
|
251 |
+
$desc .= ' (before the focus word)' ;
|
252 |
+
}
|
253 |
+
}
|
254 |
+
elsif ($head_aft_bef_s eq '0')
|
255 |
+
{
|
256 |
+
$desc = 'head = 0 instead of ' ;
|
257 |
+
if ($head_aft_bef_g eq 'a')
|
258 |
+
{
|
259 |
+
$desc.= 'after ' ;
|
260 |
+
}
|
261 |
+
if ($head_aft_bef_g eq 'b')
|
262 |
+
{
|
263 |
+
$desc.= 'before ' ;
|
264 |
+
}
|
265 |
+
$desc .= 'the focus word' ;
|
266 |
+
}
|
267 |
+
elsif ($head_aft_bef_g eq '0')
|
268 |
+
{
|
269 |
+
$desc = 'head is ' ;
|
270 |
+
if ($head_aft_bef_s eq 'a')
|
271 |
+
{
|
272 |
+
$desc.= 'after ' ;
|
273 |
+
}
|
274 |
+
if ($head_aft_bef_s eq 'b')
|
275 |
+
{
|
276 |
+
$desc.= 'before ' ;
|
277 |
+
}
|
278 |
+
$desc .= 'the focus word instead of 0' ;
|
279 |
+
}
|
280 |
+
else
|
281 |
+
{
|
282 |
+
$desc = num_as_word($head_err) ;
|
283 |
+
if ($head_err < 0)
|
284 |
+
{
|
285 |
+
$desc .= ' before' ;
|
286 |
+
}
|
287 |
+
else
|
288 |
+
{
|
289 |
+
$desc .= ' after' ;
|
290 |
+
}
|
291 |
+
|
292 |
+
$desc = 'head '.$desc.' the correct head ' ;
|
293 |
+
|
294 |
+
if ($head_aft_bef_s eq '0')
|
295 |
+
{
|
296 |
+
$desc .= '(0' ;
|
297 |
+
}
|
298 |
+
elsif ($head_aft_bef_s eq 'e')
|
299 |
+
{
|
300 |
+
$desc .= '(the focus word' ;
|
301 |
+
}
|
302 |
+
elsif ($head_aft_bef_s eq 'a')
|
303 |
+
{
|
304 |
+
$desc .= '(after the focus word' ;
|
305 |
+
}
|
306 |
+
elsif ($head_aft_bef_s eq 'b')
|
307 |
+
{
|
308 |
+
$desc .= '(before the focus word' ;
|
309 |
+
}
|
310 |
+
|
311 |
+
if ($head_aft_bef_g ne $head_aft_bef_s)
|
312 |
+
{
|
313 |
+
$desc .= ' instead of' ;
|
314 |
+
if ($head_aft_bef_s eq '0')
|
315 |
+
{
|
316 |
+
$desc .= '0' ;
|
317 |
+
}
|
318 |
+
elsif ($head_aft_bef_s eq 'e')
|
319 |
+
{
|
320 |
+
$desc .= 'the focus word' ;
|
321 |
+
}
|
322 |
+
elsif ($head_aft_bef_s eq 'a')
|
323 |
+
{
|
324 |
+
$desc .= 'after the focus word' ;
|
325 |
+
}
|
326 |
+
elsif ($head_aft_bef_s eq 'b')
|
327 |
+
{
|
328 |
+
$desc .= 'before the focus word' ;
|
329 |
+
}
|
330 |
+
}
|
331 |
+
|
332 |
+
$desc .= ')' ;
|
333 |
+
}
|
334 |
+
|
335 |
+
$desc .= ', ' ;
|
336 |
+
|
337 |
+
if ($dep_err eq '-')
|
338 |
+
{
|
339 |
+
$desc .= 'correct dependency' ;
|
340 |
+
}
|
341 |
+
else
|
342 |
+
{
|
343 |
+
($dep_g, $dep_s) = ($dep_err =~ /^(.*)->(.*)$/) ;
|
344 |
+
$desc .= sprintf('dependency "%s" instead of "%s"', $dep_s, $dep_g) ;
|
345 |
+
}
|
346 |
+
|
347 |
+
return($desc) ;
|
348 |
+
|
349 |
+
} # describe_err
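# Illustration (hypothetical call, assuming the corrected gold/system checks above):
# describe_err(2, 'ab', 'SUB->OBJ') returns
#   'head two words after the correct head (before the focus word instead of
#   after the focus word), dependency "OBJ" instead of "SUB"'.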
|
350 |
+
|
351 |
+
sub get_context
|
352 |
+
{ # get_context
|
353 |
+
|
354 |
+
my ($sent, $i_w) = @_ ;
|
355 |
+
my ($w_2, $w_1, $w1, $w2) ;
|
356 |
+
my ($p_2, $p_1, $p1, $p2) ;
|
357 |
+
|
358 |
+
if ($i_w >= 2)
|
359 |
+
{
|
360 |
+
$w_2 = ${${$sent}[$i_w-2]}{word} ;
|
361 |
+
$p_2 = ${${$sent}[$i_w-2]}{pos} ;
|
362 |
+
}
|
363 |
+
else
|
364 |
+
{
|
365 |
+
$w_2 = $START ;
|
366 |
+
$p_2 = $START ;
|
367 |
+
}
|
368 |
+
|
369 |
+
if ($i_w >= 1)
|
370 |
+
{
|
371 |
+
$w_1 = ${${$sent}[$i_w-1]}{word} ;
|
372 |
+
$p_1 = ${${$sent}[$i_w-1]}{pos} ;
|
373 |
+
}
|
374 |
+
else
|
375 |
+
{
|
376 |
+
$w_1 = $START ;
|
377 |
+
$p_1 = $START ;
|
378 |
+
}
|
379 |
+
|
380 |
+
if ($i_w <= scalar @{$sent}-2)
|
381 |
+
{
|
382 |
+
$w1 = ${${$sent}[$i_w+1]}{word} ;
|
383 |
+
$p1 = ${${$sent}[$i_w+1]}{pos} ;
|
384 |
+
}
|
385 |
+
else
|
386 |
+
{
|
387 |
+
$w1 = $END ;
|
388 |
+
$p1 = $END ;
|
389 |
+
}
|
390 |
+
|
391 |
+
if ($i_w <= scalar @{$sent}-3)
|
392 |
+
{
|
393 |
+
$w2 = ${${$sent}[$i_w+2]}{word} ;
|
394 |
+
$p2 = ${${$sent}[$i_w+2]}{pos} ;
|
395 |
+
}
|
396 |
+
else
|
397 |
+
{
|
398 |
+
$w2 = $END ;
|
399 |
+
$p2 = $END ;
|
400 |
+
}
|
401 |
+
|
402 |
+
return ($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) ;
|
403 |
+
|
404 |
+
} # get_context
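# Illustration: for the first token ($i_w == 0) both left-context slots fall
# outside the sentence, so $w_2, $w_1, $p_2 and $p_1 all come back as $START;
# symmetrically, the right-context slots of the last token come back as $END.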
|
405 |
+
|
406 |
+
sub read_sent
|
407 |
+
{ # read_sent
|
408 |
+
|
409 |
+
my ($sent_gold, $sent_sys) = @_ ;
|
410 |
+
my ($line_g, $line_s, $new_sent) ;
|
411 |
+
my (%fields_g, %fields_s) ;
|
412 |
+
|
413 |
+
$new_sent = 1 ;
|
414 |
+
|
415 |
+
@{$sent_gold} = () ;
|
416 |
+
@{$sent_sys} = () ;
|
417 |
+
|
418 |
+
while (1)
|
419 |
+
{ # main reading loop
|
420 |
+
|
421 |
+
$line_g = <GOLD> ;
|
422 |
+
$line_s = <SYS> ;
|
423 |
+
|
424 |
+
$line_num++ ;
|
425 |
+
|
426 |
+
# system output has fewer lines than gold standard
|
427 |
+
if ((defined $line_g) && (! defined $line_s))
|
428 |
+
{
|
429 |
+
printf STDERR "line mismatch, line %d:\n", $line_num ;
|
430 |
+
printf STDERR " gold: %s", $line_g ;
|
431 |
+
printf STDERR " sys : past end of file\n" ;
|
432 |
+
exit(1) ;
|
433 |
+
}
|
434 |
+
|
435 |
+
# system output has more lines than gold standard
|
436 |
+
if ((! defined $line_g) && (defined $line_s))
|
437 |
+
{
|
438 |
+
printf STDERR "line mismatch, line %d:\n", $line_num ;
|
439 |
+
printf STDERR " gold: past end of file\n" ;
|
440 |
+
printf STDERR " sys : %s", $line_s ;
|
441 |
+
exit(1) ;
|
442 |
+
}
|
443 |
+
|
444 |
+
# end of file reached for both
|
445 |
+
if ((! defined $line_g) && (! defined $line_s))
|
446 |
+
{
|
447 |
+
return (1) ;
|
448 |
+
}
|
449 |
+
|
450 |
+
# one contains end of sentence but other one does not
|
451 |
+
if (($line_g =~ /^\s+$/) != ($line_s =~ /^\s+$/))
|
452 |
+
{
|
453 |
+
printf STDERR "line mismatch, line %d:\n", $line_num ;
|
454 |
+
printf STDERR " gold: %s", $line_g ;
|
455 |
+
printf STDERR " sys : %s", $line_s ;
|
456 |
+
exit(1) ;
|
457 |
+
}
|
458 |
+
|
459 |
+
# end of sentence reached
|
460 |
+
if ($line_g =~ /^\s+$/)
|
461 |
+
{
|
462 |
+
return(0) ;
|
463 |
+
}
|
464 |
+
|
465 |
+
# now both lines contain information
|
466 |
+
|
467 |
+
if ($new_sent)
|
468 |
+
{
|
469 |
+
$new_sent = 0 ;
|
470 |
+
}
|
471 |
+
|
472 |
+
# 'official' column names
|
473 |
+
# options.output = ['id','form','lemma','cpostag','postag',
|
474 |
+
# 'feats','head','deprel','phead','pdeprel']
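# For illustration (hypothetical row, not taken from the data): a CoNLL-X line
#   1  naraH  nara  NOUN  NN  _  2  nsubj  _  _
# is sliced below with [1, 3, 6, 7] into word='naraH', pos='NOUN',
# head='2' and dep='nsubj'.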
|
475 |
+
|
476 |
+
@fields_g{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_g))[1, 3, 6, 7] ;
|
477 |
+
|
478 |
+
push @{$sent_gold}, { %fields_g } ;
|
479 |
+
|
480 |
+
@fields_s{'word', 'pos', 'head', 'dep'} = (split (/\s+/, $line_s))[1, 3, 6, 7] ;
|
481 |
+
|
482 |
+
if (($fields_g{word} ne $fields_s{word})
|
483 |
+
||
|
484 |
+
($fields_g{pos} ne $fields_s{pos}))
|
485 |
+
{
|
486 |
+
printf STDERR "Word/pos mismatch, line %d:\n", $line_num ;
|
487 |
+
printf STDERR " gold: %s", $line_g ;
|
488 |
+
printf STDERR " sys : %s", $line_s ;
|
489 |
+
exit(1) ;
|
490 |
+
}
|
491 |
+
|
492 |
+
push @{$sent_sys}, { %fields_s } ;
|
493 |
+
|
494 |
+
} # main reading loop
|
495 |
+
|
496 |
+
} # read_sent
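# Note: read_sent returns 1 when both files are exhausted and 0 at an ordinary
# sentence boundary; the reading loops below use this return value as their $eof flag.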
|
497 |
+
|
498 |
+
################################################################################
|
499 |
+
### main ###
|
500 |
+
################################################################################
|
501 |
+
|
502 |
+
our ($opt_g, $opt_s, $opt_o, $opt_h, $opt_v, $opt_q, $opt_p, $opt_b) ;
|
503 |
+
|
504 |
+
my ($sent_num, $eof, $word_num, @err_sent) ;
|
505 |
+
my (@sent_gold, @sent_sys, @starts) ;
|
506 |
+
my ($word, $pos, $wp, $head_g, $dep_g, $head_s, $dep_s) ;
|
507 |
+
my (%counts, $err_head, $err_dep, $con, $con1, $con_pos, $con_pos1, $thresh) ;
|
508 |
+
my ($head_err, $dep_err, @cur_err, %err_counts, $err_counter, $err_desc) ;
|
509 |
+
my ($loc_con, %loc_con_err_counts, %err_desc) ;
|
510 |
+
my ($head_aft_bef_g, $head_aft_bef_s, $head_aft_bef) ;
|
511 |
+
my ($con_bef, $con_aft, $con_bef_2, $con_aft_2, @bits, @e_bits, @v_con, @v_con_pos) ;
|
512 |
+
my ($con_pos_bef, $con_pos_aft, $con_pos_bef_2, $con_pos_aft_2) ;
|
513 |
+
my ($max_word_len, $max_pos_len, $max_con_len, $max_con_pos_len) ;
|
514 |
+
my ($max_word_spec_len, $max_con_bef_len, $max_con_aft_len) ;
|
515 |
+
my (%freq_err, $err) ;
|
516 |
+
|
517 |
+
my ($i, $j, $i_w, $l, $n_args) ;
|
518 |
+
my ($w_2, $w_1, $w1, $w2) ;
|
519 |
+
my ($wp_2, $wp_1, $wp1, $wp2) ;
|
520 |
+
my ($p_2, $p_1, $p1, $p2) ;
|
521 |
+
|
522 |
+
my ($short_output) ;
|
523 |
+
my ($score_on_punct) ;
|
524 |
+
$counts{punct} = 0; # initialize
|
525 |
+
|
526 |
+
getopts("g:o:s:qvhpb") ;
|
527 |
+
|
528 |
+
if (defined $opt_v)
|
529 |
+
{
|
530 |
+
my $id = '$Id: eval.pl,v 1.9 2006/05/09 20:30:01 yuval Exp $';
|
531 |
+
my @parts = split ' ',$id;
|
532 |
+
print "Version $parts[2]\n";
|
533 |
+
exit(0);
|
534 |
+
}
|
535 |
+
|
536 |
+
if ((defined $opt_h) || ((! defined $opt_g) && (! defined $opt_s)))
|
537 |
+
{
|
538 |
+
die $usage ;
|
539 |
+
}
|
540 |
+
|
541 |
+
if (! defined $opt_g)
|
542 |
+
{
|
543 |
+
die "Gold standard file (-g) missing\n" ;
|
544 |
+
}
|
545 |
+
|
546 |
+
if (! defined $opt_s)
|
547 |
+
{
|
548 |
+
die "System output file (-s) missing\n" ;
|
549 |
+
}
|
550 |
+
|
551 |
+
if (! defined $opt_o)
|
552 |
+
{
|
553 |
+
$opt_o = '-' ;
|
554 |
+
}
|
555 |
+
|
556 |
+
if (defined $opt_q)
|
557 |
+
{
|
558 |
+
$short_output = 1 ;
|
559 |
+
} else {
|
560 |
+
$short_output = 0 ;
|
561 |
+
}
|
562 |
+
|
563 |
+
if (defined $opt_p)
|
564 |
+
{
|
565 |
+
$score_on_punct = 1 ;
|
566 |
+
} else {
|
567 |
+
$score_on_punct = 0 ;
|
568 |
+
}
|
569 |
+
|
570 |
+
$line_num = 0 ;
|
571 |
+
$sent_num = 0 ;
|
572 |
+
$eof = 0 ;
|
573 |
+
|
574 |
+
@err_sent = () ;
|
575 |
+
@starts = () ;
|
576 |
+
|
577 |
+
%{$err_sent[0]} = () ;
|
578 |
+
|
579 |
+
$max_pos_len = length('CPOS') ;
|
580 |
+
|
581 |
+
################################################################################
|
582 |
+
### reading input ###
|
583 |
+
################################################################################
|
584 |
+
|
585 |
+
open (GOLD, "<$opt_g") || die "Could not open gold standard file $opt_g\n" ;
|
586 |
+
open (SYS, "<$opt_s") || die "Could not open system output file $opt_s\n" ;
|
587 |
+
open (OUT, ">$opt_o") || die "Could not open output file $opt_o\n" ;
|
588 |
+
|
589 |
+
|
590 |
+
if (defined $opt_b) { # produce output similar to evalb
|
591 |
+
print OUT " Sent. Attachment Correct Scoring \n";
|
592 |
+
print OUT " ID Tokens - Unlab. Lab. HEAD HEAD+DEPREL tokens - - - -\n";
|
593 |
+
print OUT " ============================================================================\n";
|
594 |
+
}
|
595 |
+
|
596 |
+
|
597 |
+
while (! $eof)
|
598 |
+
{ # main reading loop
|
599 |
+
|
600 |
+
$starts[$sent_num] = $line_num+1 ;
|
601 |
+
$eof = read_sent(\@sent_gold, \@sent_sys) ;
|
602 |
+
|
603 |
+
$sent_num++ ;
|
604 |
+
|
605 |
+
%{$err_sent[$sent_num]} = () ;
|
606 |
+
$word_num = scalar @sent_gold ;
|
607 |
+
|
608 |
+
# for accuracy per sentence
|
609 |
+
my %sent_counts = ( tot => 0,
|
610 |
+
err_any => 0,
|
611 |
+
err_head => 0
|
612 |
+
);
|
613 |
+
|
614 |
+
# printf "$sent_num $word_num\n" ;
|
615 |
+
|
616 |
+
my @frames_g = ('** '); # the initial frame for the virtual root
|
617 |
+
my @frames_s = ('** '); # the initial frame for the virtual root
|
618 |
+
foreach $i_w (0 .. $word_num-1)
|
619 |
+
{ # loop on words
|
620 |
+
push @frames_g, ''; # initialize
|
621 |
+
push @frames_s, ''; # initialize
|
622 |
+
}
|
623 |
+
|
624 |
+
foreach $i_w (0 .. $word_num-1)
|
625 |
+
{ # loop on words
|
626 |
+
|
627 |
+
($word, $pos, $head_g, $dep_g)
|
628 |
+
= @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
|
629 |
+
$wp = $word.' / '.$pos ;
|
630 |
+
|
631 |
+
# printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
|
632 |
+
|
633 |
+
if ((! $score_on_punct) && is_uni_punct($word))
|
634 |
+
{
|
635 |
+
$counts{punct}++ ;
|
636 |
+
# ignore punctuation
|
637 |
+
next ;
|
638 |
+
}
|
639 |
+
|
640 |
+
if (length($pos) > $max_pos_len)
|
641 |
+
{
|
642 |
+
$max_pos_len = length($pos) ;
|
643 |
+
}
|
644 |
+
|
645 |
+
($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
|
646 |
+
|
647 |
+
$counts{tot}++ ;
|
648 |
+
$counts{word}{$wp}{tot}++ ;
|
649 |
+
$counts{pos}{$pos}{tot}++ ;
|
650 |
+
$counts{head}{$head_g-$i_w-1}{tot}++ ;
|
651 |
+
|
652 |
+
# for frame confusions
|
653 |
+
# add child to frame of parent
|
654 |
+
$frames_g[$head_g] .= "$dep_g ";
|
655 |
+
$frames_s[$head_s] .= "$dep_s ";
|
656 |
+
# add to frame of token itself
|
657 |
+
$frames_g[$i_w+1] .= "*$dep_g* "; # $i_w+1 because $i_w starts counting at zero
|
658 |
+
$frames_s[$i_w+1] .= "*$dep_s* "; # system frame uses the system deprel
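# Sketch (hypothetical sentence "John saw Mary" with gold arcs John<-saw->Mary):
# processing John(SUB), saw(ROOT), Mary(OBJ) in order builds the gold frame of
# 'saw' as "SUB *ROOT* OBJ " - its children's deprels around its own starred one.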
|
659 |
+
|
660 |
+
# for precision and recall of DEPREL
|
661 |
+
$counts{dep}{$dep_g}{tot}++ ; # counts for gold standard deprels
|
662 |
+
$counts{dep2}{$dep_g}{$dep_s}++ ; # counts for confusions
|
663 |
+
$counts{dep_s}{$dep_s}{tot}++ ; # counts for system deprels
|
664 |
+
$counts{all_dep}{$dep_g} = 1 ; # list of all deprels that occur ...
|
665 |
+
$counts{all_dep}{$dep_s} = 1 ; # ... in either gold or system output
|
666 |
+
|
667 |
+
# for precision and recall of HEAD direction
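# Illustration: token position is $i_w+1, so a token at position 4 is binned
# 'left' for head 2, 'right' for head 7, 'to_root' for head 0 and 'self' for head 4.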
|
668 |
+
my $dir_g;
|
669 |
+
if ($head_g == 0) {
|
670 |
+
$dir_g = 'to_root';
|
671 |
+
} elsif ($head_g < $i_w+1) { # $i_w+1 because $i_w starts counting at zero
|
672 |
+
# also below
|
673 |
+
$dir_g = 'left';
|
674 |
+
} elsif ($head_g > $i_w+1) {
|
675 |
+
$dir_g = 'right';
|
676 |
+
} else {
|
677 |
+
# token links to itself; should never happen in correct gold standard
|
678 |
+
$dir_g = 'self';
|
679 |
+
}
|
680 |
+
my $dir_s;
|
681 |
+
if ($head_s == 0) {
|
682 |
+
$dir_s = 'to_root';
|
683 |
+
} elsif ($head_s < $i_w+1) {
|
684 |
+
$dir_s = 'left';
|
685 |
+
} elsif ($head_s > $i_w+1) {
|
686 |
+
$dir_s = 'right';
|
687 |
+
} else {
|
688 |
+
# token links to itself; should not happen in good system
|
689 |
+
# (but not forbidden in shared task)
|
690 |
+
$dir_s = 'self';
|
691 |
+
}
|
692 |
+
$counts{dir_g}{$dir_g}{tot}++ ; # counts for gold standard head direction
|
693 |
+
$counts{dir2}{$dir_g}{$dir_s}++ ; # counts for confusions
|
694 |
+
$counts{dir_s}{$dir_s}{tot}++ ; # counts for system head direction
|
695 |
+
|
696 |
+
# for precision and recall of HEAD distance
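# Illustration: the bin key is abs(head - position); a token at position 5 with
# head 3 gives |3-5| = 2 -> bin '2', with head 11 gives 6 -> bin '3-6', and a
# head of 0 is always binned 'to_root'.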
|
697 |
+
my $dist_g;
|
698 |
+
if ($head_g == 0) {
|
699 |
+
$dist_g = 'to_root';
|
700 |
+
} elsif ( abs($head_g - ($i_w+1)) <= 1 ) {
|
701 |
+
$dist_g = '1'; # includes the 'self' cases
|
702 |
+
} elsif ( abs($head_g - ($i_w+1)) <= 2 ) {
|
703 |
+
$dist_g = '2';
|
704 |
+
} elsif ( abs($head_g - ($i_w+1)) <= 6 ) {
|
705 |
+
$dist_g = '3-6';
|
706 |
+
} else {
|
707 |
+
$dist_g = '7-...';
|
708 |
+
}
|
709 |
+
my $dist_s;
|
710 |
+
if ($head_s == 0) {
|
711 |
+
$dist_s = 'to_root';
|
712 |
+
} elsif ( abs($head_s - ($i_w+1)) <= 1 ) {
|
713 |
+
$dist_s = '1'; # includes the 'self' cases
|
714 |
+
} elsif ( abs($head_s - ($i_w+1)) <= 2 ) {
|
715 |
+
$dist_s = '2';
|
716 |
+
} elsif ( abs($head_s - ($i_w+1)) <= 6 ) {
|
717 |
+
$dist_s = '3-6';
|
718 |
+
} else {
|
719 |
+
$dist_s = '7-...';
|
720 |
+
}
|
721 |
+
$counts{dist_g}{$dist_g}{tot}++ ; # counts for gold standard head distance
|
722 |
+
$counts{dist2}{$dist_g}{$dist_s}++ ; # counts for confusions
|
723 |
+
$counts{dist_s}{$dist_s}{tot}++ ; # counts for system head distance
|
724 |
+
|
725 |
+
|
726 |
+
$err_head = ($head_g ne $head_s) ; # error in head
|
727 |
+
$err_dep = ($dep_g ne $dep_s) ; # error in deprel
|
728 |
+
|
729 |
+
$head_err = '-' ;
|
730 |
+
$dep_err = '-' ;
|
731 |
+
|
732 |
+
# for accuracy per sentence
|
733 |
+
$sent_counts{tot}++ ;
|
734 |
+
if ($err_dep || $err_head) {
|
735 |
+
$sent_counts{err_any}++ ;
|
736 |
+
}
|
737 |
+
if ($err_head) {
|
738 |
+
$sent_counts{err_head}++ ;
|
739 |
+
}
|
740 |
+
|
741 |
+
# total counts and counts for CPOS involved in errors
|
742 |
+
|
743 |
+
if ($head_g eq '0')
|
744 |
+
{
|
745 |
+
$head_aft_bef_g = '0' ;
|
746 |
+
}
|
747 |
+
elsif ($head_g eq $i_w+1)
|
748 |
+
{
|
749 |
+
$head_aft_bef_g = 'e' ;
|
750 |
+
}
|
751 |
+
else
|
752 |
+
{
|
753 |
+
$head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
|
754 |
+
}
|
755 |
+
|
756 |
+
if ($head_s eq '0')
|
757 |
+
{
|
758 |
+
$head_aft_bef_s = '0' ;
|
759 |
+
}
|
760 |
+
elsif ($head_s eq $i_w+1)
|
761 |
+
{
|
762 |
+
$head_aft_bef_s = 'e' ;
|
763 |
+
}
|
764 |
+
else
|
765 |
+
{
|
766 |
+
$head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
|
767 |
+
}
|
768 |
+
|
769 |
+
$head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
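# The two-character code pairs the gold and system position classes: e.g. 'ab'
# means the gold head lies after the focus word but the system head before it
# ('0' = attached to the root, 'e' = the focus word itself).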
|
770 |
+
|
771 |
+
if ($err_head)
|
772 |
+
{
|
773 |
+
if ($head_aft_bef_s eq '0')
|
774 |
+
{
|
775 |
+
$head_err = 0 ;
|
776 |
+
}
|
777 |
+
else
|
778 |
+
{
|
779 |
+
$head_err = $head_s-$head_g ;
|
780 |
+
}
|
781 |
+
|
782 |
+
$err_sent[$sent_num]{head}++ ;
|
783 |
+
$counts{err_head}{tot}++ ;
|
784 |
+
$counts{err_head}{$head_err}++ ;
|
785 |
+
|
786 |
+
$counts{word}{err_head}{$wp}++ ;
|
787 |
+
$counts{pos}{$pos}{err_head}{tot}++ ;
|
788 |
+
$counts{pos}{$pos}{err_head}{$head_err}++ ;
|
789 |
+
}
|
790 |
+
|
791 |
+
if ($err_dep)
|
792 |
+
{
|
793 |
+
$dep_err = $dep_g.'->'.$dep_s ;
|
794 |
+
$err_sent[$sent_num]{dep}++ ;
|
795 |
+
$counts{err_dep}{tot}++ ;
|
796 |
+
$counts{err_dep}{$dep_err}++ ;
|
797 |
+
|
798 |
+
$counts{word}{err_dep}{$wp}++ ;
|
799 |
+
$counts{pos}{$pos}{err_dep}{tot}++ ;
|
800 |
+
$counts{pos}{$pos}{err_dep}{$dep_err}++ ;
|
801 |
+
|
802 |
+
if ($err_head)
|
803 |
+
{
|
804 |
+
$counts{err_both}++ ;
|
805 |
+
$counts{pos}{$pos}{err_both}++ ;
|
806 |
+
}
|
807 |
+
}
|
808 |
+
|
809 |
+
### DEPREL + ATTACHMENT
|
810 |
+
if ((!$err_dep) && ($err_head)) {
|
811 |
+
$counts{err_head_corr_dep}{tot}++ ;
|
812 |
+
$counts{err_head_corr_dep}{$dep_s}++ ;
|
813 |
+
}
|
814 |
+
### DEPREL + ATTACHMENT
|
815 |
+
|
816 |
+
# counts for words involved in errors
|
817 |
+
|
818 |
+
if (! ($err_head || $err_dep))
|
819 |
+
{
|
820 |
+
next ;
|
821 |
+
}
|
822 |
+
|
823 |
+
$err_sent[$sent_num]{word}++ ;
|
824 |
+
$counts{err_any}++ ;
|
825 |
+
$counts{word}{err_any}{$wp}++ ;
|
826 |
+
$counts{pos}{$pos}{err_any}++ ;
|
827 |
+
|
828 |
+
($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
|
829 |
+
|
830 |
+
if ($w_2 ne $START)
|
831 |
+
{
|
832 |
+
$wp_2 = $w_2.' / '.$p_2 ;
|
833 |
+
}
|
834 |
+
else
|
835 |
+
{
|
836 |
+
$wp_2 = $w_2 ;
|
837 |
+
}
|
838 |
+
|
839 |
+
if ($w_1 ne $START)
|
840 |
+
{
|
841 |
+
$wp_1 = $w_1.' / '.$p_1 ;
|
842 |
+
}
|
843 |
+
else
|
844 |
+
{
|
845 |
+
$wp_1 = $w_1 ;
|
846 |
+
}
|
847 |
+
|
848 |
+
if ($w1 ne $END)
|
849 |
+
{
|
850 |
+
$wp1 = $w1.' / '.$p1 ;
|
851 |
+
}
|
852 |
+
else
|
853 |
+
{
|
854 |
+
$wp1 = $w1 ;
|
855 |
+
}
|
856 |
+
|
857 |
+
if ($w2 ne $END)
|
858 |
+
{
|
859 |
+
$wp2 = $w2.' / '.$p2 ;
|
860 |
+
}
|
861 |
+
else
|
862 |
+
{
|
863 |
+
$wp2 = $w2 ;
|
864 |
+
}
|
865 |
+
|
866 |
+
$con_bef = $wp_1 ;
|
867 |
+
$con_bef_2 = $wp_2.' + '.$wp_1 ;
|
868 |
+
$con_aft = $wp1 ;
|
869 |
+
$con_aft_2 = $wp1.' + '.$wp2 ;
|
870 |
+
|
871 |
+
$con_pos_bef = $p_1 ;
|
872 |
+
$con_pos_bef_2 = $p_2.'+'.$p_1 ;
|
873 |
+
$con_pos_aft = $p1 ;
|
874 |
+
$con_pos_aft_2 = $p1.'+'.$p2 ;
|
875 |
+
|
876 |
+
if ($w_1 ne $START)
|
877 |
+
{
|
878 |
+
# do not count '.S' as a word context
|
879 |
+
$counts{con_bef_2}{tot}{$con_bef_2}++ ;
|
880 |
+
$counts{con_bef_2}{err_head}{$con_bef_2} += $err_head ;
|
881 |
+
$counts{con_bef_2}{err_dep}{$con_bef_2} += $err_dep ;
|
882 |
+
$counts{con_bef}{tot}{$con_bef}++ ;
|
883 |
+
$counts{con_bef}{err_head}{$con_bef} += $err_head ;
|
884 |
+
$counts{con_bef}{err_dep}{$con_bef} += $err_dep ;
|
885 |
+
}
|
886 |
+
|
887 |
+
if ($w1 ne $END)
|
888 |
+
{
|
889 |
+
# do not count '.E' as a word context
|
890 |
+
$counts{con_aft_2}{tot}{$con_aft_2}++ ;
|
891 |
+
$counts{con_aft_2}{err_head}{$con_aft_2} += $err_head ;
|
892 |
+
$counts{con_aft_2}{err_dep}{$con_aft_2} += $err_dep ;
|
893 |
+
$counts{con_aft}{tot}{$con_aft}++ ;
|
894 |
+
$counts{con_aft}{err_head}{$con_aft} += $err_head ;
|
895 |
+
$counts{con_aft}{err_dep}{$con_aft} += $err_dep ;
|
896 |
+
}
|
897 |
+
|
898 |
+
$counts{con_pos_bef_2}{tot}{$con_pos_bef_2}++ ;
|
899 |
+
$counts{con_pos_bef_2}{err_head}{$con_pos_bef_2} += $err_head ;
|
900 |
+
$counts{con_pos_bef_2}{err_dep}{$con_pos_bef_2} += $err_dep ;
|
901 |
+
$counts{con_pos_bef}{tot}{$con_pos_bef}++ ;
|
902 |
+
$counts{con_pos_bef}{err_head}{$con_pos_bef} += $err_head ;
|
903 |
+
$counts{con_pos_bef}{err_dep}{$con_pos_bef} += $err_dep ;
|
904 |
+
|
905 |
+
$counts{con_pos_aft_2}{tot}{$con_pos_aft_2}++ ;
|
906 |
+
$counts{con_pos_aft_2}{err_head}{$con_pos_aft_2} += $err_head ;
|
907 |
+
$counts{con_pos_aft_2}{err_dep}{$con_pos_aft_2} += $err_dep ;
|
908 |
+
$counts{con_pos_aft}{tot}{$con_pos_aft}++ ;
|
909 |
+
$counts{con_pos_aft}{err_head}{$con_pos_aft} += $err_head ;
|
910 |
+
$counts{con_pos_aft}{err_dep}{$con_pos_aft} += $err_dep ;
|
911 |
+
|
912 |
+
$err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
|
913 |
+
$freq_err{$err}++ ;
|
914 |
+
|
915 |
+
} # loop on words
|
916 |
+
|
917 |
+
foreach $i_w (0 .. $word_num) # including one for the virtual root
|
918 |
+
{ # loop on words
|
919 |
+
if ($frames_g[$i_w] ne $frames_s[$i_w]) {
|
920 |
+
$counts{frame2}{"$frames_g[$i_w]/ $frames_s[$i_w]"}++ ;
|
921 |
+
}
|
922 |
+
}
|
923 |
+
|
924 |
+
if (defined $opt_b) { # produce output similar to evalb
|
925 |
+
if ($word_num > 0) {
|
926 |
+
my ($unlabeled,$labeled) = ('NaN', 'NaN');
|
927 |
+
if ($sent_counts{tot} > 0) { # there are scoring tokens
|
928 |
+
$unlabeled = 100-$sent_counts{err_head}*100.0/$sent_counts{tot};
|
929 |
+
$labeled = 100-$sent_counts{err_any} *100.0/$sent_counts{tot};
|
930 |
+
}
|
931 |
+
printf OUT " %4d %4d 0 %6.2f %6.2f %4d %4d %4d 0 0 0 0\n",
|
932 |
+
$sent_num, $word_num,
|
933 |
+
$unlabeled, $labeled,
|
934 |
+
$sent_counts{tot}-$sent_counts{err_head},
|
935 |
+
$sent_counts{tot}-$sent_counts{err_any},
|
936 |
+
$sent_counts{tot};
|
937 |
+
}
|
938 |
+
}
|
939 |
+
|
940 |
+
} # main reading loop
|
941 |
+
|
942 |
+
################################################################################
|
943 |
+
### printing output ###
|
944 |
+
################################################################################
|
945 |
+
|
946 |
+
if (defined $opt_b) { # produce output similar to evalb
|
947 |
+
print OUT "\n\n";
|
948 |
+
}
|
949 |
+
printf OUT " Labeled attachment score: %d / %d * 100 = %.2f %%\n",
|
950 |
+
$counts{tot}-$counts{err_any}, $counts{tot}, 100-$counts{err_any}*100.0/$counts{tot} ;
|
951 |
+
printf OUT " Unlabeled attachment score: %d / %d * 100 = %.2f %%\n",
|
952 |
+
$counts{tot}-$counts{err_head}{tot}, $counts{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot} ;
|
953 |
+
printf OUT " Label accuracy score: %d / %d * 100 = %.2f %%\n",
|
954 |
+
$counts{tot}-$counts{err_dep}{tot}, $counts{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot} ;
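# Worked example with hypothetical counts: tot=100, err_any=12, err_head{tot}=8
# and err_dep{tot}=5 would print LAS = 88/100*100 = 88.00 %, UAS = 92/100*100
# = 92.00 % and label accuracy = 95/100*100 = 95.00 %.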
|
955 |
+
|
956 |
+
if ($short_output)
|
957 |
+
{
|
958 |
+
exit(0) ;
|
959 |
+
}
|
960 |
+
printf OUT "\n %s\n\n", '=' x 80 ;
|
961 |
+
printf OUT " Evaluation of the results in %s\n vs. gold standard %s:\n\n", $opt_s, $opt_g ;
|
962 |
+
|
963 |
+
printf OUT " Legend: '%s' - the beginning of a sentence, '%s' - the end of a sentence\n\n", $START, $END ;
|
964 |
+
|
965 |
+
printf OUT " Number of non-scoring tokens: $counts{punct}\n\n";
|
966 |
+
|
967 |
+
printf OUT " The overall accuracy and its distribution over CPOSTAGs\n\n" ;
|
968 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
969 |
+
|
970 |
+
printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
|
971 |
+
'Accuracy', 'words', 'right', 'right', 'both' ;
|
972 |
+
printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
|
973 |
+
' ', ' ', 'head', ' dep', 'right' ;
|
974 |
+
|
975 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
976 |
+
|
977 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
978 |
+
'total', $counts{tot},
|
979 |
+
$counts{tot}-$counts{err_head}{tot}, 100-$counts{err_head}{tot}*100.0/$counts{tot},
|
980 |
+
$counts{tot}-$counts{err_dep}{tot}, 100-$counts{err_dep}{tot}*100.0/$counts{tot},
|
981 |
+
$counts{tot}-$counts{err_any}, 100-$counts{err_any}*100.0/$counts{tot} ;
|
982 |
+
|
983 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
984 |
+
|
985 |
+
foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
|
986 |
+
{
|
987 |
+
if (! defined($counts{pos}{$pos}{err_head}{tot}))
|
988 |
+
{
|
989 |
+
$counts{pos}{$pos}{err_head}{tot} = 0 ;
|
990 |
+
}
|
991 |
+
if (! defined($counts{pos}{$pos}{err_dep}{tot}))
|
992 |
+
{
|
993 |
+
$counts{pos}{$pos}{err_dep}{tot} = 0 ;
|
994 |
+
}
|
995 |
+
if (! defined($counts{pos}{$pos}{err_any}))
|
996 |
+
{
|
997 |
+
$counts{pos}{$pos}{err_any} = 0 ;
|
998 |
+
}
|
999 |
+
|
1000 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
1001 |
+
$pos, $counts{pos}{$pos}{tot},
|
1002 |
+
$counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_head}{tot}, 100-$counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
|
1003 |
+
$counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_dep}{tot}, 100-$counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
|
1004 |
+
$counts{pos}{$pos}{tot}-$counts{pos}{$pos}{err_any}, 100-$counts{pos}{$pos}{err_any}*100.0/$counts{pos}{$pos}{tot} ;
|
1005 |
+
}
|
1006 |
+
|
1007 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
1008 |
+
|
1009 |
+
printf OUT "\n\n" ;
|
1010 |
+
|
1011 |
+
printf OUT " The overall error rate and its distribution over CPOSTAGs\n\n" ;
|
1012 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
1013 |
+
|
1014 |
+
printf OUT " %-10s | %-5s | %-5s | %% | %-5s | %% | %-5s | %%\n",
|
1015 |
+
'Error', 'words', 'head', ' dep', 'both' ;
|
1016 |
+
printf OUT " %-10s | %-5s | %-5s | | %-5s | | %-5s |\n",
|
1017 |
+
|
1018 |
+
'Rate', ' ', 'err', ' err', 'wrong' ;
|
1019 |
+
|
1020 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
1021 |
+
|
1022 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
1023 |
+
'total', $counts{tot},
|
1024 |
+
$counts{err_head}{tot}, $counts{err_head}{tot}*100.0/$counts{tot},
|
1025 |
+
$counts{err_dep}{tot}, $counts{err_dep}{tot}*100.0/$counts{tot},
|
1026 |
+
$counts{err_both}, $counts{err_both}*100.0/$counts{tot} ;
|
1027 |
+
|
1028 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
1029 |
+
|
1030 |
+
foreach $pos (sort {$counts{pos}{$b}{tot} <=> $counts{pos}{$a}{tot}} keys %{$counts{pos}})
|
1031 |
+
{
|
1032 |
+
if (! defined($counts{pos}{$pos}{err_both}))
|
1033 |
+
{
|
1034 |
+
$counts{pos}{$pos}{err_both} = 0 ;
|
1035 |
+
}
|
1036 |
+
|
1037 |
+
printf OUT " %-10s | %5d | %5d | %3.0f%% | %5d | %3.0f%% | %5d | %3.0f%%\n",
|
1038 |
+
$pos, $counts{pos}{$pos}{tot},
|
1039 |
+
$counts{pos}{$pos}{err_head}{tot}, $counts{pos}{$pos}{err_head}{tot}*100.0/$counts{pos}{$pos}{tot},
|
1040 |
+
$counts{pos}{$pos}{err_dep}{tot}, $counts{pos}{$pos}{err_dep}{tot}*100.0/$counts{pos}{$pos}{tot},
|
1041 |
+
$counts{pos}{$pos}{err_both}, $counts{pos}{$pos}{err_both}*100.0/$counts{pos}{$pos}{tot} ;
|
1042 |
+
|
1043 |
+
}
|
1044 |
+
|
1045 |
+
printf OUT "%s\n", " -----------+-------+-------+------+-------+------+-------+-------" ;
|
1046 |
+
|
1047 |
+
### added by Sabine Buchholz
|
1048 |
+
printf OUT "\n\n";
|
1049 |
+
printf OUT " Precision and recall of DEPREL\n\n";
|
1050 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1051 |
+
printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
|
1052 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1053 |
+
foreach my $dep (sort keys %{$counts{all_dep}}) {
|
1054 |
+
# initialize
|
1055 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
1056 |
+
|
1057 |
+
if (defined($counts{dep2}{$dep}{$dep})) {
|
1058 |
+
$tot_corr = $counts{dep2}{$dep}{$dep};
|
1059 |
+
}
|
1060 |
+
if (defined($counts{dep}{$dep}{tot})) {
|
1061 |
+
$tot_g = $counts{dep}{$dep}{tot};
|
1062 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
1063 |
+
}
|
1064 |
+
if (defined($counts{dep_s}{$dep}{tot})) {
|
1065 |
+
$tot_s = $counts{dep_s}{$dep}{tot};
|
1066 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
1067 |
+
}
|
1068 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
1069 |
+
$dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
1070 |
+
}
|
1071 |
+
|
1072 |
+
### DEPREL + ATTACHMENT:
|
1073 |
+
### Same as Sabine's DEPREL apart from $tot_corr calculation
|
1074 |
+
printf OUT "\n\n";
|
1075 |
+
printf OUT " Precision and recall of DEPREL + ATTACHMENT\n\n";
|
1076 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1077 |
+
printf OUT " deprel | gold | correct | system | recall (%%) | precision (%%) \n";
|
1078 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1079 |
+
foreach my $dep (sort keys %{$counts{all_dep}}) {
|
1080 |
+
# initialize
|
1081 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
1082 |
+
|
1083 |
+
if (defined($counts{dep2}{$dep}{$dep})) {
|
1084 |
+
if (defined($counts{err_head_corr_dep}{$dep})) {
|
1085 |
+
$tot_corr = $counts{dep2}{$dep}{$dep} - $counts{err_head_corr_dep}{$dep};
|
1086 |
+
} else {
|
1087 |
+
$tot_corr = $counts{dep2}{$dep}{$dep};
|
1088 |
+
}
|
1089 |
+
}
|
1090 |
+
if (defined($counts{dep}{$dep}{tot})) {
|
1091 |
+
$tot_g = $counts{dep}{$dep}{tot};
|
1092 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
1093 |
+
}
|
1094 |
+
if (defined($counts{dep_s}{$dep}{tot})) {
|
1095 |
+
$tot_s = $counts{dep_s}{$dep}{tot};
|
1096 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
1097 |
+
}
|
1098 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
1099 |
+
$dep, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
1100 |
+
}
|
1101 |
+
### DEPREL + ATTACHMENT
|
1102 |
+
|
1103 |
+
printf OUT "\n\n";
|
1104 |
+
printf OUT " Precision and recall of binned HEAD direction\n\n";
|
1105 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1106 |
+
printf OUT " direction | gold | correct | system | recall (%%) | precision (%%) \n";
|
1107 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1108 |
+
foreach my $dir ('to_root', 'left', 'right', 'self') {
|
1109 |
+
# initialize
|
1110 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
1111 |
+
|
1112 |
+
if (defined($counts{dir2}{$dir}{$dir})) {
|
1113 |
+
$tot_corr = $counts{dir2}{$dir}{$dir};
|
1114 |
+
}
|
1115 |
+
if (defined($counts{dir_g}{$dir}{tot})) {
|
1116 |
+
$tot_g = $counts{dir_g}{$dir}{tot};
|
1117 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
1118 |
+
}
|
1119 |
+
if (defined($counts{dir_s}{$dir}{tot})) {
|
1120 |
+
$tot_s = $counts{dir_s}{$dir}{tot};
|
1121 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
1122 |
+
}
|
1123 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
1124 |
+
$dir, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
1125 |
+
}
|
1126 |
+
|
1127 |
+
printf OUT "\n\n";
|
1128 |
+
printf OUT " Precision and recall of binned HEAD distance\n\n";
|
1129 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1130 |
+
printf OUT " distance | gold | correct | system | recall (%%) | precision (%%) \n";
|
1131 |
+
printf OUT " ----------------+------+---------+--------+------------+---------------\n";
|
1132 |
+
foreach my $dist ('to_root', '1', '2', '3-6', '7-...') {
|
1133 |
+
# initialize
|
1134 |
+
my ($tot_corr, $tot_g, $tot_s, $prec, $rec) = (0, 0, 0, 'NaN', 'NaN');
|
1135 |
+
|
1136 |
+
if (defined($counts{dist2}{$dist}{$dist})) {
|
1137 |
+
$tot_corr = $counts{dist2}{$dist}{$dist};
|
1138 |
+
}
|
1139 |
+
if (defined($counts{dist_g}{$dist}{tot})) {
|
1140 |
+
$tot_g = $counts{dist_g}{$dist}{tot};
|
1141 |
+
$rec = sprintf("%.2f",$tot_corr / $tot_g * 100);
|
1142 |
+
}
|
1143 |
+
if (defined($counts{dist_s}{$dist}{tot})) {
|
1144 |
+
$tot_s = $counts{dist_s}{$dist}{tot};
|
1145 |
+
$prec = sprintf("%.2f",$tot_corr / $tot_s * 100);
|
1146 |
+
}
|
1147 |
+
printf OUT " %-15s | %4d | %7d | %6d | %10s | %13s\n",
|
1148 |
+
$dist, $tot_g, $tot_corr, $tot_s, $rec, $prec;
|
1149 |
+
}
|
1150 |
+
|
1151 |
+
printf OUT "\n\n";
|
1152 |
+
printf OUT " Frame confusions (gold versus system; *...* marks the head token)\n\n";
|
1153 |
+
foreach my $frame (sort {$counts{frame2}{$b} <=> $counts{frame2}{$a}} keys %{$counts{frame2}})
|
1154 |
+
{
|
1155 |
+
if ($counts{frame2}{$frame} >= 5) # (make 5 a changeable threshold later)
|
1156 |
+
{
|
1157 |
+
printf OUT " %3d %s\n", $counts{frame2}{$frame}, $frame;
|
1158 |
+
}
|
1159 |
+
}
|
1160 |
+
### end of: added by Sabine Buchholz
|
1161 |
+
|
1162 |
+
|
1163 |
+
#
|
1164 |
+
# Leave only the 5 words mostly involved in errors
|
1165 |
+
#
|
1166 |
+
|
1167 |
+
|
1168 |
+
$thresh = (sort {$b <=> $a} values %{$counts{word}{err_any}})[4] ;
|
1169 |
+
|
1170 |
+
# ensure enough space for title
|
1171 |
+
$max_word_len = length('word') ;
|
1172 |
+
|
1173 |
+
foreach $word (keys %{$counts{word}{err_any}})
|
1174 |
+
{
|
1175 |
+
if ($counts{word}{err_any}{$word} < $thresh)
|
1176 |
+
{
|
1177 |
+
delete $counts{word}{err_any}{$word} ;
|
1178 |
+
next ;
|
1179 |
+
}
|
1180 |
+
|
1181 |
+
$l = uni_len($word) ;
|
1182 |
+
if ($l > $max_word_len)
|
1183 |
+
{
|
1184 |
+
$max_word_len = $l ;
|
1185 |
+
}
|
1186 |
+
}
|
1187 |
+
|
1188 |
+
# filter a case when the difference between the error counts
|
1189 |
+
# for 2-word and 1-word contexts is small
|
1190 |
+
# (leave the 2-word context)
|
1191 |
+
|
1192 |
+
foreach $con (keys %{$counts{con_aft_2}{tot}})
|
1193 |
+
{
|
1194 |
+
($w1) = split(/\+/, $con) ;
|
1195 |
+
|
1196 |
+
if (defined $counts{con_aft}{tot}{$w1} &&
|
1197 |
+
$counts{con_aft}{tot}{$w1}-$counts{con_aft_2}{tot}{$con} <= 1)
|
1198 |
+
{
|
1199 |
+
delete $counts{con_aft}{tot}{$w1} ;
|
1200 |
+
}
|
1201 |
+
}
|
1202 |
+
|
1203 |
+
foreach $con (keys %{$counts{con_bef_2}{tot}})
|
1204 |
+
{
|
1205 |
+
($w_2, $w_1) = split(/\+/, $con) ;
|
1206 |
+
|
1207 |
+
if (defined $counts{con_bef}{tot}{$w_1} &&
|
1208 |
+
$counts{con_bef}{tot}{$w_1}-$counts{con_bef_2}{tot}{$con} <= 1)
|
1209 |
+
{
|
1210 |
+
delete $counts{con_bef}{tot}{$w_1} ;
|
1211 |
+
}
|
1212 |
+
}
|
1213 |
+
|
1214 |
+
foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
|
1215 |
+
{
|
1216 |
+
($p1) = split(/\+/, $con_pos) ;
|
1217 |
+
|
1218 |
+
if (defined($counts{con_pos_aft}{tot}{$p1}) &&
|
1219 |
+
$counts{con_pos_aft}{tot}{$p1}-$counts{con_pos_aft_2}{tot}{$con_pos} <= 1)
|
1220 |
+
{
|
1221 |
+
delete $counts{con_pos_aft}{tot}{$p1} ;
|
1222 |
+
}
|
1223 |
+
}
|
1224 |
+
|
1225 |
+
foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
|
1226 |
+
{
|
1227 |
+
($p_2, $p_1) = split(/\+/, $con_pos) ;
|
1228 |
+
|
1229 |
+
if (defined($counts{con_pos_bef}{tot}{$p_1}) &&
|
1230 |
+
$counts{con_pos_bef}{tot}{$p_1}-$counts{con_pos_bef_2}{tot}{$con_pos} <= 1)
|
1231 |
+
{
|
1232 |
+
delete $counts{con_pos_bef}{tot}{$p_1} ;
|
1233 |
+
}
|
1234 |
+
}
|
1235 |
+
|
1236 |
+
# for each context type, take the three contexts most involved in errors
|
1237 |
+
|
1238 |
+
$max_con_len = 0 ;
|
1239 |
+
|
1240 |
+
filter_context_counts($counts{con_bef_2}{tot}, $con_err_num, \$max_con_len) ;
|
1241 |
+
|
1242 |
+
filter_context_counts($counts{con_bef}{tot}, $con_err_num, \$max_con_len) ;
|
1243 |
+
|
1244 |
+
filter_context_counts($counts{con_aft}{tot}, $con_err_num, \$max_con_len) ;
|
1245 |
+
|
1246 |
+
filter_context_counts($counts{con_aft_2}{tot}, $con_err_num, \$max_con_len) ;
|
1247 |
+
|
1248 |
+
# for each CPOS context type, take the three CPOS contexts most involved in errors
|
1249 |
+
|
1250 |
+
$max_con_pos_len = 0 ;
|
1251 |
+
|
1252 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef_2}{tot}})[$con_err_num-1] ;
|
1253 |
+
|
1254 |
+
foreach $con_pos (keys %{$counts{con_pos_bef_2}{tot}})
|
1255 |
+
{
|
1256 |
+
if ($counts{con_pos_bef_2}{tot}{$con_pos} < $thresh)
|
1257 |
+
{
|
1258 |
+
delete $counts{con_pos_bef_2}{tot}{$con_pos} ;
|
1259 |
+
next ;
|
1260 |
+
}
|
1261 |
+
if (length($con_pos) > $max_con_pos_len)
|
1262 |
+
{
|
1263 |
+
$max_con_pos_len = length($con_pos) ;
|
1264 |
+
}
|
1265 |
+
}
|
1266 |
+
|
1267 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_bef}{tot}})[$con_err_num-1] ;
|
1268 |
+
|
1269 |
+
foreach $con_pos (keys %{$counts{con_pos_bef}{tot}})
|
1270 |
+
{
|
1271 |
+
if ($counts{con_pos_bef}{tot}{$con_pos} < $thresh)
|
1272 |
+
{
|
1273 |
+
delete $counts{con_pos_bef}{tot}{$con_pos} ;
|
1274 |
+
next ;
|
1275 |
+
}
|
1276 |
+
if (length($con_pos) > $max_con_pos_len)
|
1277 |
+
{
|
1278 |
+
$max_con_pos_len = length($con_pos) ;
|
1279 |
+
}
|
1280 |
+
}
|
1281 |
+
|
1282 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft}{tot}})[$con_err_num-1] ;
|
1283 |
+
|
1284 |
+
foreach $con_pos (keys %{$counts{con_pos_aft}{tot}})
|
1285 |
+
{
|
1286 |
+
if ($counts{con_pos_aft}{tot}{$con_pos} < $thresh)
|
1287 |
+
{
|
1288 |
+
delete $counts{con_pos_aft}{tot}{$con_pos} ;
|
1289 |
+
next ;
|
1290 |
+
}
|
1291 |
+
if (length($con_pos) > $max_con_pos_len)
|
1292 |
+
{
|
1293 |
+
$max_con_pos_len = length($con_pos) ;
|
1294 |
+
}
|
1295 |
+
}
|
1296 |
+
|
1297 |
+
$thresh = (sort {$b <=> $a} values %{$counts{con_pos_aft_2}{tot}})[$con_err_num-1] ;
|
1298 |
+
|
1299 |
+
foreach $con_pos (keys %{$counts{con_pos_aft_2}{tot}})
|
1300 |
+
{
|
1301 |
+
if ($counts{con_pos_aft_2}{tot}{$con_pos} < $thresh)
|
1302 |
+
{
|
1303 |
+
delete $counts{con_pos_aft_2}{tot}{$con_pos} ;
|
1304 |
+
next ;
|
1305 |
+
}
|
1306 |
+
if (length($con_pos) > $max_con_pos_len)
|
1307 |
+
{
|
1308 |
+
$max_con_pos_len = length($con_pos) ;
|
1309 |
+
}
|
1310 |
+
}
|
1311 |
+
|
1312 |
+
# printing
|
1313 |
+
|
1314 |
+
# ------------- focus words
|
1315 |
+
|
1316 |
+
printf OUT "\n\n" ;
|
1317 |
+
printf OUT " %d focus words where most of the errors occur:\n\n", scalar keys %{$counts{word}{err_any}} ;
|
1318 |
+
|
1319 |
+
printf OUT " %-*s | %-4s | %-4s | %-4s | %-4s\n", $max_word_len, ' ', 'any', 'head', 'dep', 'both' ;
|
1320 |
+
printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
|
1321 |
+
|
1322 |
+
foreach $word (sort {$counts{word}{err_any}{$b} <=> $counts{word}{err_any}{$a}} keys %{$counts{word}{err_any}})
|
1323 |
+
{
|
1324 |
+
if (!defined($counts{word}{err_head}{$word}))
|
1325 |
+
{
|
1326 |
+
$counts{word}{err_head}{$word} = 0 ;
|
1327 |
+
}
|
1328 |
+
if (! defined($counts{word}{err_dep}{$word}))
|
1329 |
+
{
|
1330 |
+
$counts{word}{err_dep}{$word} = 0 ;
|
1331 |
+
}
|
1332 |
+
if (! defined($counts{word}{err_any}{$word}))
|
1333 |
+
{
|
1334 |
+
$counts{word}{err_any}{$word} = 0;
|
1335 |
+
}
|
1336 |
+
printf OUT " %-*s | %4d | %4d | %4d | %4d\n",
|
1337 |
+
$max_word_len+length($word)-uni_len($word), $word, $counts{word}{err_any}{$word},
|
1338 |
+
$counts{word}{err_head}{$word},
|
1339 |
+
$counts{word}{err_dep}{$word},
|
1340 |
+
$counts{word}{err_dep}{$word}+$counts{word}{err_head}{$word}-$counts{word}{err_any}{$word} ;
|
1341 |
+
}
|
1342 |
+
|
1343 |
+
printf OUT " %s-+------+------+------+------\n", '-' x $max_word_len;
|
1344 |
+
|
1345 |
+
# ------------- contexts
|
1346 |
+
|
1347 |
+
printf OUT "\n\n" ;
|
1348 |
+
|
1349 |
+
printf OUT " one-token preceeding contexts where most of the errors occur:\n\n" ;
|
1350 |
+
|
1351 |
+
print_context($counts{con_bef}, $counts{con_pos_bef}, $max_con_len, $max_con_pos_len) ;
|
1352 |
+
|
1353 |
+
printf OUT " two-token preceeding contexts where most of the errors occur:\n\n" ;
|
1354 |
+
|
1355 |
+
print_context($counts{con_bef_2}, $counts{con_pos_bef_2}, $max_con_len, $max_con_pos_len) ;
|
1356 |
+
|
1357 |
+
printf OUT " one-token following contexts where most of the errors occur:\n\n" ;
|
1358 |
+
|
1359 |
+
print_context($counts{con_aft}, $counts{con_pos_aft}, $max_con_len, $max_con_pos_len) ;
|
1360 |
+
|
1361 |
+
printf OUT " two-token following contexts where most of the errors occur:\n\n" ;
|
1362 |
+
|
1363 |
+
print_context($counts{con_aft_2}, $counts{con_pos_aft_2}, $max_con_len, $max_con_pos_len) ;
|
1364 |
+
|
1365 |
+
# ------------- Sentences
|
1366 |
+
|
1367 |
+
printf OUT " Sentence with the highest number of word errors:\n" ;
|
1368 |
+
$i = (sort { (defined($err_sent[$b]{word}) && $err_sent[$b]{word})
|
1369 |
+
<=> (defined($err_sent[$a]{word}) && $err_sent[$a]{word}) } 1 .. $sent_num)[0] ;
|
1370 |
+
printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
|
1371 |
+
printf OUT "%d head errors, %d dependency errors, %d word errors\n",
|
1372 |
+
$err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
|
1373 |
+
|
1374 |
+
printf OUT "\n\n" ;
|
1375 |
+
|
1376 |
+
printf OUT " Sentence with the highest number of head errors:\n" ;
|
1377 |
+
$i = (sort { (defined($err_sent[$b]{head}) && $err_sent[$b]{head})
|
1378 |
+
<=> (defined($err_sent[$a]{head}) && $err_sent[$a]{head}) } 1 .. $sent_num)[0] ;
|
1379 |
+
printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
|
1380 |
+
printf OUT "%d head errors, %d dependency errors, %d word errors\n",
|
1381 |
+
$err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
|
1382 |
+
|
1383 |
+
printf OUT "\n\n" ;
|
1384 |
+
|
1385 |
+
printf OUT " Sentence with the highest number of dependency errors:\n" ;
|
1386 |
+
$i = (sort { (defined($err_sent[$b]{dep}) && $err_sent[$b]{dep})
|
1387 |
+
<=> (defined($err_sent[$a]{dep}) && $err_sent[$a]{dep}) } 1 .. $sent_num)[0] ;
|
1388 |
+
printf OUT " Sentence %d line %d, ", $i, $starts[$i-1] ;
|
1389 |
+
printf OUT "%d head errors, %d dependency errors, %d word errors\n",
|
1390 |
+
$err_sent[$i]{head}, $err_sent[$i]{dep}, $err_sent[$i]{word} ;
|
1391 |
+
|
1392 |
+
#
|
1393 |
+
# Second pass, collect statistics of the frequent errors
|
1394 |
+
#
|
1395 |
+
|
1396 |
+
# filter the errors, leave the most frequent $freq_err_num errors
|
1397 |
+
|
1398 |
+
$i = 0 ;
|
1399 |
+
|
1400 |
+
$thresh = (sort {$b <=> $a} values %freq_err)[$freq_err_num-1] ;
|
1401 |
+
|
1402 |
+
foreach $err (keys %freq_err)
|
1403 |
+
{
|
1404 |
+
if ($freq_err{$err} < $thresh)
|
1405 |
+
{
|
1406 |
+
delete $freq_err{$err} ;
|
1407 |
+
}
|
1408 |
+
}
|
1409 |
+
|
1410 |
+
# in case there are several errors with the threshold count
|
1411 |
+
|
1412 |
+
$freq_err_num = scalar keys %freq_err ;
|
1413 |
+
|
1414 |
+
%err_counts = () ;
|
1415 |
+
|
1416 |
+
$eof = 0 ;
|
1417 |
+
|
1418 |
+
seek (GOLD, 0, 0) ;
|
1419 |
+
seek (SYS, 0, 0) ;
|
1420 |
+
|
1421 |
+
while (! $eof)
|
1422 |
+
{ # second reading loop
|
1423 |
+
|
1424 |
+
$eof = read_sent(\@sent_gold, \@sent_sys) ;
|
1425 |
+
$sent_num++ ;
|
1426 |
+
|
1427 |
+
$word_num = scalar @sent_gold ;
|
1428 |
+
|
1429 |
+
# printf "$sent_num $word_num\n" ;
|
1430 |
+
|
1431 |
+
foreach $i_w (0 .. $word_num-1)
|
1432 |
+
{ # loop on words
|
1433 |
+
($word, $pos, $head_g, $dep_g)
|
1434 |
+
= @{$sent_gold[$i_w]}{'word', 'pos', 'head', 'dep'} ;
|
1435 |
+
|
1436 |
+
# printf "%d: %s %s %s %s\n", $i_w, $word, $pos, $head_g, $dep_g ;
|
1437 |
+
|
1438 |
+
if ((! $score_on_punct) && is_uni_punct($word))
|
1439 |
+
{
|
1440 |
+
# ignore punctuation
|
1441 |
+
next ;
|
1442 |
+
}
|
1443 |
+
|
1444 |
+
($head_s, $dep_s) = @{$sent_sys[$i_w]}{'head', 'dep'} ;
|
1445 |
+
|
1446 |
+
$err_head = ($head_g ne $head_s) ;
|
1447 |
+
$err_dep = ($dep_g ne $dep_s) ;
|
1448 |
+
|
1449 |
+
$head_err = '-' ;
|
1450 |
+
$dep_err = '-' ;
|
1451 |
+
|
1452 |
+
if ($head_g eq '0')
|
1453 |
+
{
|
1454 |
+
$head_aft_bef_g = '0' ;
|
1455 |
+
}
|
1456 |
+
elsif ($head_g eq $i_w+1)
|
1457 |
+
{
|
1458 |
+
$head_aft_bef_g = 'e' ;
|
1459 |
+
}
|
1460 |
+
else
|
1461 |
+
{
|
1462 |
+
$head_aft_bef_g = ($head_g <= $i_w+1 ? 'b' : 'a') ;
|
1463 |
+
}
|
1464 |
+
|
1465 |
+
if ($head_s eq '0')
|
1466 |
+
{
|
1467 |
+
$head_aft_bef_s = '0' ;
|
1468 |
+
}
|
1469 |
+
elsif ($head_s eq $i_w+1)
|
1470 |
+
{
|
1471 |
+
$head_aft_bef_s = 'e' ;
|
1472 |
+
}
|
1473 |
+
else
|
1474 |
+
{
|
1475 |
+
$head_aft_bef_s = ($head_s <= $i_w+1 ? 'b' : 'a') ;
|
1476 |
+
}
|
1477 |
+
|
1478 |
+
$head_aft_bef = $head_aft_bef_g.$head_aft_bef_s ;
|
1479 |
+
|
1480 |
+
if ($err_head)
|
1481 |
+
{
|
1482 |
+
if ($head_aft_bef_s eq '0')
|
1483 |
+
{
|
1484 |
+
$head_err = 0 ;
|
1485 |
+
}
|
1486 |
+
else
|
1487 |
+
{
|
1488 |
+
$head_err = $head_s-$head_g ;
|
1489 |
+
}
|
1490 |
+
}
|
1491 |
+
|
1492 |
+
if ($err_dep)
|
1493 |
+
{
|
1494 |
+
$dep_err = $dep_g.'->'.$dep_s ;
|
1495 |
+
}
|
1496 |
+
|
1497 |
+
if (! ($err_head || $err_dep))
|
1498 |
+
{
|
1499 |
+
next ;
|
1500 |
+
}
|
1501 |
+
|
1502 |
+
# handle only the most frequent errors
|
1503 |
+
|
1504 |
+
$err = $head_err.$sep.$head_aft_bef.$sep.$dep_err ;
|
1505 |
+
|
1506 |
+
if (! exists $freq_err{$err})
|
1507 |
+
{
|
1508 |
+
next ;
|
1509 |
+
}
|
1510 |
+
|
1511 |
+
($w_2, $w_1, $w1, $w2, $p_2, $p_1, $p1, $p2) = get_context(\@sent_gold, $i_w) ;
|
1512 |
+
|
1513 |
+
$con_bef = $w_1 ;
|
1514 |
+
$con_bef_2 = $w_2.' + '.$w_1 ;
|
1515 |
+
$con_aft = $w1 ;
|
1516 |
+
$con_aft_2 = $w1.' + '.$w2 ;
|
1517 |
+
|
1518 |
+
$con_pos_bef = $p_1 ;
|
1519 |
+
$con_pos_bef_2 = $p_2.'+'.$p_1 ;
|
1520 |
+
$con_pos_aft = $p1 ;
|
1521 |
+
$con_pos_aft_2 = $p1.'+'.$p2 ;
|
1522 |
+
|
1523 |
+
@cur_err = ($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) ;
|
1524 |
+
|
1525 |
+
# printf "# %-25s %-15s %-10s %-25s %-3s %-30s\n",
|
1526 |
+
# $con_bef, $word, $pos, $con_aft, $head_err, $dep_err ;
|
1527 |
+
|
1528 |
+
@bits = (0, 0, 0, 0, 0, 0) ;
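# The while loop below drives @bits through all 2^6 = 64 bit patterns (a binary
# counter with carry; the all-zero pattern comes last, once the counter
# overflows and $j becomes 1). Fields whose bit is 0 are generalized to '*',
# so every subset of the six context fields gets its own count.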
|
1529 |
+
$j = 0 ;
|
1530 |
+
|
1531 |
+
while ($j == 0)
|
1532 |
+
{
|
1533 |
+
for ($i = 0; $i <= $#bits; $i++)
|
1534 |
+
{
|
1535 |
+
if ($bits[$i] == 0)
|
1536 |
+
{
|
1537 |
+
$bits[$i] = 1 ;
|
1538 |
+
$j = 0 ;
|
1539 |
+
last ;
|
1540 |
+
}
|
1541 |
+
else
|
1542 |
+
{
|
1543 |
+
$bits[$i] = 0 ;
|
1544 |
+
$j = 1 ;
|
1545 |
+
}
|
1546 |
+
}
|
1547 |
+
|
1548 |
+
@e_bits = @cur_err ;
|
1549 |
+
|
1550 |
+
for ($i = 0; $i <= $#bits; $i++)
|
1551 |
+
{
|
1552 |
+
if (! $bits[$i])
|
1553 |
+
{
|
1554 |
+
$e_bits[$i] = '*' ;
|
1555 |
+
}
|
1556 |
+
}
|
1557 |
+
|
1558 |
+
# include also the last case which is the most general
|
1559 |
+
# (wildcards for everything)
|
1560 |
+
$err_counts{$err}{join($sep, @e_bits)}++ ;
|
1561 |
+
|
1562 |
+
}
|
1563 |
+
|
1564 |
+
} # loop on words
|
1565 |
+
} # second reading loop
|
1566 |
+
|
1567 |
+
printf OUT "\n\n" ;
|
1568 |
+
printf OUT " Specific errors, %d most frequent errors:", $freq_err_num ;
|
1569 |
+
printf OUT "\n %s\n", '=' x 41 ;
|
1570 |
+
|
1571 |
+
|
1572 |
+
# deleting local contexts which are too general
|
1573 |
+
|
1574 |
+
foreach $err (keys %err_counts)
|
1575 |
+
{
|
1576 |
+
foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
|
1577 |
+
keys %{$err_counts{$err}})
|
1578 |
+
{
|
1579 |
+
@cur_err = split(/\Q$sep\E/, $loc_con) ;
|
1580 |
+
|
1581 |
+
# In this loop, one or two elements of the local context are
|
1582 |
+
# replaced with '*' to make it more general. If the entry for
|
1583 |
+
# the general context has the same count it is removed.
|
1584 |
+
|
1585 |
+
foreach $i (0 .. $#cur_err)
|
1586 |
+
{
|
1587 |
+
$w1 = $cur_err[$i] ;
|
1588 |
+
if ($cur_err[$i] eq '*')
|
1589 |
+
{
|
1590 |
+
next ;
|
1591 |
+
}
|
1592 |
+
$cur_err[$i] = '*' ;
|
1593 |
+
$con1 = join($sep, @cur_err) ;
|
1594 |
+
if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
|
1595 |
+
&& ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
|
1596 |
+
{
|
1597 |
+
delete $err_counts{$err}{$con1} ;
|
1598 |
+
}
|
1599 |
+
for ($j = $i+1; $j <=$#cur_err; $j++)
|
1600 |
+
{
|
1601 |
+
if ($cur_err[$j] eq '*')
|
1602 |
+
{
|
1603 |
+
next ;
|
1604 |
+
}
|
1605 |
+
$w2 = $cur_err[$j] ;
|
1606 |
+
$cur_err[$j] = '*' ;
|
1607 |
+
$con1 = join($sep, @cur_err) ;
|
1608 |
+
if ( defined($err_counts{$err}{$con1}) && defined($err_counts{$err}{$loc_con})
|
1609 |
+
&& ($err_counts{$err}{$con1} == $err_counts{$err}{$loc_con}))
|
1610 |
+
{
|
1611 |
+
delete $err_counts{$err}{$con1} ;
|
1612 |
+
}
|
1613 |
+
$cur_err[$j] = $w2 ;
|
1614 |
+
}
|
1615 |
+
$cur_err[$i] = $w1 ;
|
1616 |
+
}
|
1617 |
+
}
|
1618 |
+
}
|
1619 |
+
|
1620 |
+
# Leaving only the topmost local contexts for each error
|
1621 |
+
|
1622 |
+
foreach $err (keys %err_counts)
|
1623 |
+
{
|
1624 |
+
$thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[$spec_err_loc_con-1] || 0 ;
|
1625 |
+
|
1626 |
+
# if the threshold is too low, take the 2nd highest count
|
1627 |
+
# (the highest may be the total which is the generic case
|
1628 |
+
# and not relevant for printing)
|
1629 |
+
|
1630 |
+
if ($thresh < 5)
|
1631 |
+
{
|
1632 |
+
$thresh = (sort {$b <=> $a} values %{$err_counts{$err}})[1] ;
|
1633 |
+
}
|
1634 |
+
|
1635 |
+
foreach $loc_con (keys %{$err_counts{$err}})
|
1636 |
+
{
|
1637 |
+
if ($err_counts{$err}{$loc_con} < $thresh)
|
1638 |
+
{
|
1639 |
+
delete $err_counts{$err}{$loc_con} ;
|
1640 |
+
}
|
1641 |
+
else
|
1642 |
+
{
|
1643 |
+
if ($loc_con ne join($sep, ('*', '*', '*', '*', '*', '*')))
|
1644 |
+
{
|
1645 |
+
$loc_con_err_counts{$loc_con}{$err} = $err_counts{$err}{$loc_con} ;
|
1646 |
+
}
|
1647 |
+
}
|
1648 |
+
}
|
1649 |
+
}
|
1650 |
+
|
1651 |
+
# printing an error summary
|
1652 |
+
|
1653 |
+
# calculating the context field length
|
1654 |
+
|
1655 |
+
$max_word_spec_len= length('word') ;
|
1656 |
+
$max_con_aft_len = length('word') ;
|
1657 |
+
$max_con_bef_len = length('word') ;
|
1658 |
+
$max_con_pos_len = length('CPOS') ;
|
1659 |
+
|
1660 |
+
foreach $err (keys %err_counts)
|
1661 |
+
{
|
1662 |
+
foreach $loc_con (sort keys %{$err_counts{$err}})
|
1663 |
+
{
|
1664 |
+
($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
|
1665 |
+
split(/\Q$sep\E/, $loc_con) ;
|
1666 |
+
|
1667 |
+
$l = uni_len($word) ;
|
1668 |
+
if ($l > $max_word_spec_len)
|
1669 |
+
{
|
1670 |
+
$max_word_spec_len = $l ;
|
1671 |
+
}
|
1672 |
+
|
1673 |
+
$l = uni_len($con_bef) ;
|
1674 |
+
if ($l > $max_con_bef_len)
|
1675 |
+
{
|
1676 |
+
$max_con_bef_len = $l ;
|
1677 |
+
}
|
1678 |
+
|
1679 |
+
$l = uni_len($con_aft) ;
|
1680 |
+
if ($l > $max_con_aft_len)
|
1681 |
+
{
|
1682 |
+
$max_con_aft_len = $l ;
|
1683 |
+
}
|
1684 |
+
|
1685 |
+
if (length($con_pos_aft) > $max_con_pos_len)
|
1686 |
+
{
|
1687 |
+
$max_con_pos_len = length($con_pos_aft) ;
|
1688 |
+
}
|
1689 |
+
|
1690 |
+
if (length($con_pos_bef) > $max_con_pos_len)
|
1691 |
+
{
|
1692 |
+
$max_con_pos_len = length($con_pos_bef) ;
|
1693 |
+
}
|
1694 |
+
}
|
1695 |
+
}
|
1696 |
+
|
1697 |
+
$err_counter = 0 ;
|
1698 |
+
|
1699 |
+
foreach $err (sort {$freq_err{$b} <=> $freq_err{$a}} keys %freq_err)
|
1700 |
+
{
|
1701 |
+
|
1702 |
+
($head_err, $head_aft_bef, $dep_err) = split(/\Q$sep\E/, $err) ;
|
1703 |
+
|
1704 |
+
$err_counter++ ;
|
1705 |
+
$err_desc{$err} = sprintf("%2d. ", $err_counter).
|
1706 |
+
describe_err($head_err, $head_aft_bef, $dep_err) ;
|
1707 |
+
|
1708 |
+
# printf OUT " %-3s %-30s %d\n", $head_err, $dep_err, $freq_err{$err} ;
|
1709 |
+
printf OUT "\n" ;
|
1710 |
+
printf OUT " %s : %d times\n", $err_desc{$err}, $freq_err{$err} ;
|
1711 |
+
|
1712 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
|
1713 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
1714 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
1715 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
1716 |
+
|
1717 |
+
printf OUT " %-*s | %-*s | %-*s | %s\n",
|
1718 |
+
$max_con_pos_len+$max_con_bef_len+3, ' Before',
|
1719 |
+
$max_word_spec_len+$max_pos_len+3, ' Focus',
|
1720 |
+
$max_con_pos_len+$max_con_aft_len+3, ' After',
|
1721 |
+
'Count' ;
|
1722 |
+
|
1723 |
+
printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s |\n",
|
1724 |
+
$max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
|
1725 |
+
$max_pos_len, 'CPOS', $max_word_spec_len, 'word',
|
1726 |
+
$max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
|
1727 |
+
|
1728 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
|
1729 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
1730 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
1731 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
1732 |
+
|
1733 |
+
foreach $loc_con (sort {$err_counts{$err}{$b} <=> $err_counts{$err}{$a}}
|
1734 |
+
keys %{$err_counts{$err}})
|
1735 |
+
{
|
1736 |
+
if ($loc_con eq join($sep, ('*', '*', '*', '*', '*', '*')))
|
1737 |
+
{
|
1738 |
+
next ;
|
1739 |
+
}
|
1740 |
+
|
1741 |
+
$con1 = $loc_con ;
|
1742 |
+
$con1 =~ s/\*/ /g ;
|
1743 |
+
|
1744 |
+
($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
|
1745 |
+
split(/\Q$sep\E/, $con1) ;
|
1746 |
+
|
1747 |
+
printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %3d\n",
|
1748 |
+
$max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
|
1749 |
+
$max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
|
1750 |
+
$max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft,
|
1751 |
+
$err_counts{$err}{$loc_con} ;
|
1752 |
+
}
|
1753 |
+
|
1754 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-+------\n",
|
1755 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
1756 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
1757 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
1758 |
+
|
1759 |
+
}
|
1760 |
+
|
1761 |
+
printf OUT "\n\n" ;
|
1762 |
+
printf OUT " Local contexts involved in several frequent errors:" ;
|
1763 |
+
printf OUT "\n %s\n", '=' x 51 ;
|
1764 |
+
printf OUT "\n\n" ;
|
1765 |
+
|
1766 |
+
foreach $loc_con (sort {scalar keys %{$loc_con_err_counts{$b}} <=>
|
1767 |
+
scalar keys %{$loc_con_err_counts{$a}}}
|
1768 |
+
keys %loc_con_err_counts)
|
1769 |
+
{
|
1770 |
+
|
1771 |
+
if (scalar keys %{$loc_con_err_counts{$loc_con}} == 1)
|
1772 |
+
{
|
1773 |
+
next ;
|
1774 |
+
}
|
1775 |
+
|
1776 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
|
1777 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
1778 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
1779 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
1780 |
+
|
1781 |
+
printf OUT " %-*s | %-*s | %-*s \n",
|
1782 |
+
$max_con_pos_len+$max_con_bef_len+3, ' Before',
|
1783 |
+
$max_word_spec_len+$max_pos_len+3, ' Focus',
|
1784 |
+
$max_con_pos_len+$max_con_aft_len+3, ' After' ;
|
1785 |
+
|
1786 |
+
printf OUT " %-*s %-*s | %-*s %-*s | %-*s %-*s \n",
|
1787 |
+
$max_con_pos_len, 'CPOS', $max_con_bef_len, 'word',
|
1788 |
+
$max_pos_len, 'CPOS', $max_word_spec_len, 'word',
|
1789 |
+
$max_con_pos_len, 'CPOS', $max_con_aft_len, 'word' ;
|
1790 |
+
|
1791 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
|
1792 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
1793 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
1794 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
1795 |
+
|
1796 |
+
$con1 = $loc_con ;
|
1797 |
+
$con1 =~ s/\*/ /g ;
|
1798 |
+
|
1799 |
+
($con_pos_bef, $con_bef, $word, $pos, $con_pos_aft, $con_aft) =
|
1800 |
+
split(/\Q$sep\E/, $con1) ;
|
1801 |
+
|
1802 |
+
printf OUT " %-*s | %-*s | %-*s | %-*s | %-*s | %-*s \n",
|
1803 |
+
$max_con_pos_len, $con_pos_bef, $max_con_bef_len+length($con_bef)-uni_len($con_bef), $con_bef,
|
1804 |
+
$max_pos_len, $pos, $max_word_spec_len+length($word)-uni_len($word), $word,
|
1805 |
+
$max_con_pos_len, $con_pos_aft, $max_con_aft_len+length($con_aft)-uni_len($con_aft), $con_aft ;
|
1806 |
+
|
1807 |
+
printf OUT " %s-+-%s-+-%s-+-%s-+-%s-+-%s-\n",
|
1808 |
+
'-' x $max_con_pos_len, '-' x $max_con_bef_len,
|
1809 |
+
'-' x $max_pos_len, '-' x $max_word_spec_len,
|
1810 |
+
'-' x $max_con_pos_len, '-' x $max_con_aft_len ;
|
1811 |
+
|
1812 |
+
foreach $err (sort {$loc_con_err_counts{$loc_con}{$b} <=>
|
1813 |
+
$loc_con_err_counts{$loc_con}{$a}}
|
1814 |
+
keys %{$loc_con_err_counts{$loc_con}})
|
1815 |
+
{
|
1816 |
+
printf OUT " %s : %d times\n", $err_desc{$err},
|
1817 |
+
$loc_con_err_counts{$loc_con}{$err} ;
|
1818 |
+
}
|
1819 |
+
|
1820 |
+
printf OUT "\n" ;
|
1821 |
+
}
|
1822 |
+
|
1823 |
+
close GOLD ;
|
1824 |
+
close SYS ;
|
1825 |
+
|
1826 |
+
close OUT ;
|
examples/test_original_dcst.sh
ADDED
@@ -0,0 +1,110 @@
1 |
+
|
2 |
+
domain="san"
|
3 |
+
word_path="../Documents/DCST/data/multilingual_word_embeddings/cc.sanskrit.300.new.vec"
|
4 |
+
char_path="../Documents/DCST/data/multilingual_word_embeddings/hellwig_char_embedding.128"
|
5 |
+
pos_path='../Documents/DCST/data/multilingual_word_embeddings/pos_embedding_FT.100'
|
6 |
+
declare -i num_epochs=1
|
7 |
+
declare -i word_dim=300
|
8 |
+
declare -i set_num_training_samples=500
|
9 |
+
start_time=`date +%s`
|
10 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
11 |
+
# model_path="ud_parser_san_"$current_time
|
12 |
+
|
13 |
+
# ######## Changes ###########
|
14 |
+
# ## --use_pos is removed from all the models
|
15 |
+
# ## singletons are excluded.
|
16 |
+
# ## auxiliary task epochs changed to 20
|
17 |
+
|
18 |
+
touch saved_models/log.txt
|
19 |
+
################################################################
|
20 |
+
echo "#################################################################"
|
21 |
+
echo "Currently base model in progress..."
|
22 |
+
echo "#################################################################"
|
23 |
+
python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
|
24 |
+
--num_epochs $num_epochs --batch_size 16 --hidden_size 512 --arc_space 512 \
|
25 |
+
--arc_tag_space 128 --num_layers 3 --num_filters 100 --use_char \
|
26 |
+
--word_dim $word_dim --pos_dim 100 --initializer xavier --opt adam \
|
27 |
+
--learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 \
|
28 |
+
--epsilon 1e-6 --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
|
29 |
+
--punct_set '.' '``' ':' ',' --word_embedding fasttext \
|
30 |
+
--set_num_training_samples $set_num_training_samples \
|
31 |
+
--word_path $word_path --char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
|
32 |
+
--model_path saved_models/$model_path 2>&1 | tee saved_models/log.txt
|
33 |
+
mv saved_models/log.txt saved_models/$model_path/log.txt
|
34 |
+
|
35 |
+
|
36 |
+
## model_path="ud_parser_san_2020.03.03-12.24.41"
|
37 |
+
|
38 |
+
####### self-training ###########
|
39 |
+
## 'number_of_children' 'relative_pos_based' 'distance_from_the_root'
|
40 |
+
|
41 |
+
####### Multitask setup #########
|
42 |
+
## 'Multitask_label_predict' 'Multitask_case_predict' 'Multitask_POS_predict' 'Multitask_coarse_predict'
|
43 |
+
## 'predict_ma_tag_of_modifier' 'add_label' 'predict_case_of_modifier'
|
44 |
+
|
45 |
+
####### Other Tasks #############
|
46 |
+
## # 'add_head_ma' 'add_head_coarse_pos' 'predict_coarse_of_modifier'
|
47 |
+
|
48 |
+
for task in 'Multitask_label_predict' 'Multitask_case_predict' 'Multitask_POS_predict' 'number_of_children' 'relative_pos_based' 'distance_from_the_root' ; do
|
49 |
+
echo "#################################################################"
|
50 |
+
echo "Currently $task in progress..."
|
51 |
+
echo "#################################################################"
|
52 |
+
touch saved_models/log.txt
|
53 |
+
python examples/SequenceTagger.py --dataset ud --domain $domain --task $task \
|
54 |
+
--rnn_mode LSTM --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
|
55 |
+
--tag_space 128 --num_layers 3 --num_filters 100 --use_char \
|
56 |
+
--pos_dim 100 --initializer xavier --opt adam --learning_rate 0.002 --decay_rate 0.5 \
|
57 |
+
--schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 --p_rnn 0.33 0.33 \
|
58 |
+
--p_in 0.33 --p_out 0.33 --punct_set '.' '``' ':' ',' \
|
59 |
+
--word_dim $word_dim --word_embedding fasttext --word_path $word_path \
|
60 |
+
--parser_path saved_models/$model_path/ \
|
61 |
+
--use_unlabeled_data --char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
|
62 |
+
--model_path saved_models/$model_path/$task/ 2>&1 | tee saved_models/log.txt
|
63 |
+
mv saved_models/log.txt saved_models/$model_path/$task/log.txt
|
64 |
+
done
|
65 |
+
|
66 |
+
echo "#################################################################"
|
67 |
+
echo "Currently final model in progress..."
|
68 |
+
echo "#################################################################"
|
69 |
+
touch saved_models/log.txt
|
70 |
+
python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
|
71 |
+
--num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
|
72 |
+
--arc_space 512 --arc_tag_space 128 --num_layers 3 --num_filters 100 --use_char \
|
73 |
+
--word_dim $word_dim --pos_dim 100 --initializer xavier --opt adam \
|
74 |
+
--learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
|
75 |
+
--p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
|
76 |
+
--punct_set '.' '``' ':' ',' --word_embedding fasttext \
|
77 |
+
--word_path $word_path --gating --num_gates 4 \
|
78 |
+
--char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
|
79 |
+
--load_sequence_taggers_paths saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
|
80 |
+
saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
|
81 |
+
saved_models/$model_path/Multitask_label_predict/domain_$domain.pt \
|
82 |
+
--model_path saved_models/$model_path/final_ensembled_multi/ 2>&1 | tee saved_models/log.txt
|
83 |
+
mv saved_models/log.txt saved_models/$model_path/final_ensembled_multi/log.txt
|
84 |
+
end_time=`date +%s`
|
85 |
+
echo execution time was `expr $end_time - $start_time` s.
|
86 |
+
|
87 |
+
|
88 |
+
echo "#################################################################"
|
89 |
+
echo "Currently final model in progress..."
|
90 |
+
echo "#################################################################"
|
91 |
+
touch saved_models/log.txt
|
92 |
+
python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
|
93 |
+
--num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
|
94 |
+
--arc_space 512 --arc_tag_space 128 --num_layers 3 --num_filters 100 --use_char \
|
95 |
+
--word_dim $word_dim --pos_dim 100 --initializer xavier --opt adam \
|
96 |
+
--learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
|
97 |
+
--p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
|
98 |
+
--punct_set '.' '``' ':' ',' --word_embedding fasttext \
|
99 |
+
--word_path $word_path --gating --num_gates 7 \
|
100 |
+
--char_path $char_path --pos_path $pos_path --pos_embedding random --char_embedding random --char_dim 100 \
|
101 |
+
--load_sequence_taggers_paths saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
|
102 |
+
saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
|
103 |
+
saved_models/$model_path/Multitask_label_predict/domain_$domain.pt \
|
104 |
+
saved_models/$model_path/number_of_children/domain_$domain.pt \
|
105 |
+
saved_models/$model_path/relative_pos_based/domain_$domain.pt \
|
106 |
+
saved_models/$model_path/distance_from_the_root/domain_$domain.pt \
|
107 |
+
--model_path saved_models/$model_path/final_ensembled_multi_self/ 2>&1 | tee saved_models/log.txt
|
108 |
+
mv saved_models/log.txt saved_models/$model_path/final_ensembled_multi_self/log.txt
|
109 |
+
end_time=`date +%s`
|
110 |
+
echo execution time was `expr $end_time - $start_time` s.
|
run_san_LCM.sh
ADDED
@@ -0,0 +1,73 @@
#!/bin/bash
domain="san"

word_path="./data/multilingual_word_embeddings/cc.sanskrit.300.new.vec"
declare -i num_epochs=100
declare -i word_dim=300
declare -i set_num_training_samples=500
start_time=`date +%s`
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
model_path="ud_parser_san_"$current_time


echo "#################################################################"
echo "Currently BiAFFINE model training in progress..."
echo "#################################################################"
python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
    --num_epochs $num_epochs --batch_size 16 --hidden_size 512 --arc_space 512 \
    --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char \
    --set_num_training_samples $set_num_training_samples \
    --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
    --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 \
    --epsilon 1e-6 --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst \
    --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --pos_embedding random --word_path $word_path \
    --model_path saved_models/$model_path 2>&1 | tee saved_models/base_log.txt

mv saved_models/base_log.txt saved_models/$model_path/base_log.txt

##################################################################
## Pretraining Step
## For the DCST setting, set tasks to: 'number_of_children' 'relative_pos_based' 'distance_from_the_root'
## For LCM, set tasks to: 'Multitask_case_predict' 'Multitask_POS_predict' 'add_label'
for task in 'Multitask_case_predict' 'Multitask_POS_predict' 'add_label'; do
    touch saved_models/$model_path/log.txt
    echo "#################################################################"
    echo "Currently $task in progress..."
    echo "#################################################################"
    python examples/SequenceTagger.py --dataset ud --domain $domain --task $task \
        --rnn_mode LSTM --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
        --tag_space 128 --num_layers 2 --num_filters 100 --use_char --char_dim 100 \
        --pos_dim 100 --initializer xavier --opt adam --learning_rate 0.002 --decay_rate 0.5 \
        --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 --p_rnn 0.33 0.33 \
        --p_in 0.33 --p_out 0.33 --punct_set '.' '``' ':' ',' \
        --word_dim $word_dim --word_embedding fasttext --word_path $word_path --pos_embedding random \
        --parser_path saved_models/$model_path/ \
        --use_unlabeled_data --use_labeled_data --char_embedding random \
        --model_path saved_models/$model_path/$task/ 2>&1 | tee saved_models/$model_path/log.txt
    mv saved_models/$model_path/log.txt saved_models/$model_path/$task/log.txt
done

# ########################################################################

echo "#################################################################"
echo "Final Parsing Model Integrated with Pretrained Encoders of Auxiliary tasks..."
echo "#################################################################"
touch saved_models/$model_path/log.txt
python examples/GraphParser.py --dataset ud --domain $domain --rnn_mode LSTM \
    --num_epochs $num_epochs --batch_size 16 --hidden_size 512 \
    --arc_space 512 --arc_tag_space 128 --num_layers 2 --num_filters 100 --use_char \
    --word_dim $word_dim --char_dim 100 --pos_dim 100 --initializer xavier --opt adam \
    --learning_rate 0.002 --decay_rate 0.5 --schedule 6 --clip 5.0 --gamma 0.0 --epsilon 1e-6 \
    --p_rnn 0.33 0.33 --p_in 0.33 --p_out 0.33 --arc_decode mst --pos_embedding random \
    --punct_set '.' '``' ':' ',' --word_embedding fasttext --char_embedding random --word_path $word_path \
    --gating --num_gates 4 \
    --set_num_training_samples $set_num_training_samples \
    --load_sequence_taggers_paths saved_models/$model_path/add_label/domain_$domain.pt \
    saved_models/$model_path/Multitask_case_predict/domain_$domain.pt \
    saved_models/$model_path/Multitask_POS_predict/domain_$domain.pt \
    --model_path saved_models/$model_path/final_ensembled_BiAFF_LCM 2>&1 | tee saved_models/$model_path/log.txt
mv saved_models/$model_path/log.txt saved_models/$model_path/final_ensembled_BiAFF_LCM/log.txt

end_time=`date +%s`
echo execution time was `expr $end_time - $start_time` s.
utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
from . import io_
from . import nn
from . import load_word_embeddings
from . import nlinalg
from . import models

__version__ = "0.1.dev1"
utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (290 Bytes).
utils/__pycache__/load_word_embeddings.cpython-37.pyc
ADDED
Binary file (2.63 kB).
utils/io_/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .alphabet import *
from .instance import *
from .logger import *
from .writer import *
from . import prepare_data
utils/io_/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (228 Bytes).
utils/io_/__pycache__/alphabet.cpython-37.pyc
ADDED
Binary file (5.35 kB).
utils/io_/__pycache__/instance.cpython-37.pyc
ADDED
Binary file (1.09 kB).
utils/io_/__pycache__/logger.cpython-37.pyc
ADDED
Binary file (545 Bytes).
utils/io_/__pycache__/prepare_data.cpython-37.pyc
ADDED
Binary file (10.2 kB).
utils/io_/__pycache__/reader.cpython-37.pyc
ADDED
Binary file (2.37 kB).
utils/io_/__pycache__/rearrange_splits.cpython-37.pyc
ADDED
Binary file (2.79 kB).
utils/io_/__pycache__/seeds.cpython-37.pyc
ADDED
Binary file (358 Bytes).
utils/io_/__pycache__/write_extra_labels.cpython-37.pyc
ADDED
Binary file (36.3 kB).
utils/io_/__pycache__/writer.cpython-37.pyc
ADDED
Binary file (2.13 kB).
utils/io_/alphabet.py
ADDED
@@ -0,0 +1,147 @@
"""
Alphabet maps objects to integer ids. It provides a two-way mapping between indices and objects.
"""
import json
import os
from .logger import get_logger

class Alphabet(object):
    def __init__(self, name, defualt_value=False, keep_growing=True, singleton=False):
        self.__name = name

        self.instance2index = {}
        self.instances = []
        self.default_value = defualt_value
        self.offset = 1 if self.default_value else 0
        self.keep_growing = keep_growing
        self.singletons = set() if singleton else None

        # Index 0 is occupied by default, all else following.
        self.default_index = 0 if self.default_value else None

        self.next_index = self.offset

        self.logger = get_logger("Alphabet")

    def add(self, instance):
        if instance not in self.instance2index and instance != '<_UNK>':
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def add_singleton(self, id):
        if self.singletons is None:
            raise RuntimeError("Alphabet %s does not have singletons." % self.__name)
        else:
            self.singletons.add(id)

    def add_singletons(self, ids):
        if self.singletons is None:
            raise RuntimeError("Alphabet %s does not have singletons." % self.__name)
        else:
            self.singletons.update(ids)

    def is_singleton(self, id):
        if self.singletons is None:
            raise RuntimeError("Alphabet %s does not have singletons." % self.__name)
        else:
            return id in self.singletons

    def get_index(self, instance):
        try:
            return self.instance2index[instance]
        except KeyError:
            if self.keep_growing:
                index = self.next_index
                self.add(instance)
                return index
            else:
                if self.default_value:
                    return self.default_index
                else:
                    raise KeyError("instance not found: %s" % instance)

    def get_instance(self, index):
        if self.default_value and index == self.default_index:
            # First index is occupied by the wildcard element.
            return "<_UNK>"
        else:
            try:
                return self.instances[index - self.offset]
            except IndexError:
                raise IndexError("unknown index: %d" % index)

    def size(self):
        return len(self.instances) + self.offset

    def singleton_size(self):
        return len(self.singletons)

    def items(self):
        return self.instance2index.items()

    def keys(self):
        return self.instance2index.keys()

    def values(self):
        return self.instance2index.values()

    def token_in_alphabet(self, token):
        return token in self.instance2index

    def enumerate_items(self, start):
        if start < self.offset or start >= self.size():
            raise IndexError("Enumerate is allowed between [%d : size of the alphabet)" % self.offset)
        return zip(range(start, len(self.instances) + self.offset), self.instances[start - self.offset:])

    def close(self):
        self.keep_growing = False

    def open(self):
        self.keep_growing = True

    def get_content(self):
        if self.singletons is None:
            return {"instance2index": self.instance2index, "instances": self.instances}
        else:
            return {"instance2index": self.instance2index, "instances": self.instances,
                    "singletions": list(self.singletons)}

    def __from_json(self, data):
        self.instances = data["instances"]
        self.instance2index = data["instance2index"]
        if "singletions" in data:
            self.singletons = set(data["singletions"])
        else:
            self.singletons = None

    def save(self, output_directory, name=None):
        """
        Save the alphabet records to the given directory.
        :param output_directory: Directory to save model and weights.
        :param name: The alphabet saving name, optional.
        :return:
        """
        saving_name = name if name else self.__name
        try:
            if not os.path.exists(output_directory):
                os.makedirs(output_directory)

            json.dump(self.get_content(),
                      open(os.path.join(output_directory, saving_name + ".json"), 'w'), indent=4)
        except Exception as e:
            self.logger.warning("Alphabet is not saved: %s" % repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet records from the given directory. This allows us to reuse old saved_models
        even if the structure changes.
        :param input_directory: Directory from which to load the records.
        :return:
        """
        loading_name = name if name else self.__name
        filename = os.path.join(input_directory, loading_name + ".json")
        f = json.load(open(filename))
        self.__from_json(f)
        self.next_index = len(self.instances) + self.offset
        self.keep_growing = False
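For orientation, here is a minimal usage sketch of the Alphabet class above; the tag values and output directory are made up for illustration:

from utils.io_.alphabet import Alphabet

# Build a small tag alphabet with an <_UNK> slot at index 0.
pos_alphabet = Alphabet('pos', defualt_value=True)
for tag in ['NOUN', 'VERB', 'ADJ']:          # illustrative tags only
    pos_alphabet.add(tag)
pos_alphabet.close()                         # stop growing; unseen tags now map to <_UNK>

print(pos_alphabet.get_index('VERB'))        # 2 (index 0 is reserved for the default slot)
print(pos_alphabet.get_index('PUNCT'))       # 0, the default <_UNK> index
print(pos_alphabet.get_instance(2))          # 'VERB'
pos_alphabet.save('saved_models/alphabets')  # writes saved_models/alphabets/pos.json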
utils/io_/coarse_to_ma_dict.json
ADDED
@@ -0,0 +1 @@
{"Noun": ["acc. pl. m.", "acc. pl. *", "inst. sg. m.", "abl. du. f.", "nom. sg. *.", "loc. pl. n.", "loc. du. f.", "nom. sg. m", "g. sg. *.", "i. sg. m", "g. pl. *", "nom. pl. n.", "g. du. f.", "acc. du. f.", "loc. du. m.", "nom. sg. *", "nom. sg. m.", "loc. pl. m.", "inst. pl. n.", "g. du. n.", "nom. du. f.", "inst. pl. f.", "acc. sg. n", "abl.sg.m", "loc. pl. *", "abl. pl. n.", "i. pl. n.", "dat. du. m.", "abl. pl. m.", "inst. sg. n.", "acc. pl. n.", "nom. du. *.", "inst. du. n.", "nom. pl. *", "loc. du. *.", "abl. du. n.", "nom. pl. *.", "i. sg. m.", "dat. sg. f.", "inst. sg. *.", "abl. sg. m.", "acc. sg. *", "dat. sg. *", "dat. pl. *", "dat. du. n.", "i. pl. m.", "inst. du. m.", "nom. sg. f.", "g. sg. *", "acc. sg. *.", "i. sg. n.", "i. pl. *", "loc. du. n.", "abl. sg. *", "loc. sg. m.", "nom.sg.m", "i. sg. n..", "abl. sg. f.", "nom. sg. sg.", "dat. du. f.", "nom. sg. n.", "loc. sg. f.", "acc. du. *", "acc. pl. *.", "dat. sg. m.", "nom. sg. n", "nom. pl. f.", "i. pl. f.", "acc. du. m.", "inst. sg. f.", "g. sg. n.", "abl. sg. n", "inst. pl. m.", "loc. pl. f.", "acc. sg. f.", "g. sg. f.", "voc. du. m.", "abl. sg. n.", "voc. sg. *", "dat. sg. n.", "i. sg. *", "acc. du. n.", "g. sg. m.", "loc. sg. n", "dat. pl. m.", "nom. du. *", "nom. du. n.", "acc. sg. m.", "nom. pl. m.", "acc. pl. f.", "acc. sg. f", "g. pl. m.", "voc. du. f.", "g. pl. *.", "voc. du. n.", "dat. pl. f.", "nom. du. m.", "abl. pl. f.", "dat. pl. n.", "dat. sg. m", "loc. sg. *.", "g. pl. n.", "loc. sg. n.", "abl. sg. m", "g. du. m.", "loc. sg. *", "voc. sg. m.", "voc. pl. m.", "voc. sg. n.", "acc. sg. n.", "g. pl. f.", "i. sg. f.", "abl. du. m.", "voc. sg. f.", "voc. pl. f."], "FV": ["impft. [2] ac. du. 3", "imp. [1] ac. du. 1", "imp. [3] ac. pl. 2", "ca. pr. md. sg. 1", "ca. impft. ac. sg. 1", "pp. . pl. ", "imp. [4] ac. sg. 2", "impft . ps. pl. 3", "pr. ac. sg. 1", "pp. . sg. ", "pr. [6] ac. pl. 3", "imp. [1] ac. pl. 2", "cond. ac. sg. 3", "pft. ac. sg. 2", "impft. [6] md. du. 2", "pr. [7] md. sg. 3", "pr. ps. sg. 2", "impft. [1] md. pl. 3", "ca. impft. ac. pl. 1", "pr. [10] ac. sg. 3", "pr. [5] ac. pl. 3", "imp. md. sg. 3", "fut. ac. du. 1", "imp. [10] ac. pl. 2", "impft. [1] md. sg. 3", "impft. [6] ac. sg. 2", "pr. [6] ac. sg. 1", "ca. impft. ac. pl. 3", "opt. [8] ac. sg. 3", "des. pr. md. du. 1", "ca. imp. ac. pl. 1", "pr. [1] md. pl. 3", "aor. [5] ac. sg. 2", "pfp. md. sg. 3", "imp. [2] md. sg. 1", "impft. [3] ac. sg. 3", "fut. ac. pl. 1", "pr. [vn.] ac. pl. 3", "ca. imp. ps. sg. 3", "impft. [10] ac. sg. 3", "imp. ac. pl. 3", "impft. [8] ac. pl. 3", "impft. [10] ac. sg. 2", "impft. [1] ac. pl. 3", "pr. [10] ac. du. 3", "impft. [10] ac. du. 3", "pp. . sg. 1", "pr. [4] md. sg. 3", "opt. [1] ac. pl. 2", "pfp. md. sg. 1", "pr. [2] ac. sg. 3", "opt. ac. pl. 1", "pp. ac. sg. 2", "pr. [5] ac. du. 3", "pr. md. sg. ", "pr. [8] md. sg. 3", "ca. imp. ac. sg. 2", "imp. ac. pl. 1", "imp. ps. sg. 2", "pft. md. pl. ", "ca. impft. md. sg. 3", "pr. [6] md. sg. 3", "opt. [1] ac. sg. 3", "pft. ac. sg. 1", "pft. ac. du. 3", "pr. [vn.] ac. sg. 3", "imp. ps. pl. 3", "opt. ps. sg. 3", "pr. ac. pl. 1", "pft. md. du. 3", "pr. [1] md. sg. 1", "imp. [1] ac. sg. 3", "pr. [3] ac. sg. 1", "pft. ps. sg. 3", "ca. pr. ac. du. 3", "impft. [9] ac. sg. 3", "impft . ac. sg. 1", "inj. [2] ac. sg. 3", "imp. [1] ac. sg. 2", "pr. [6] ac. du. 3", "fut. ac. du. 3", "imp. [vn.] ac. sg. 2", "imp. ac. sg. 2", "pr. ps. pl. 3", "pr. [6] ac. pl. 2", "ppr. . pl. ", "ppr. . sg. ", "pr. md. sg. 1", "imp. [9] ac. sg. 3", "impft. md. sg. ", "aor. [2] ac. sg. 3", "impft. md. sg. 2", "pr. md. sg. 3", "impft. [2] ac. sg. 1", "fut. ac. sg. 2", "opt. [2] ac. sg. 3", "imp. [8] ac. sg. 3", "imp. [10] ac. sg. 2", "pr. [4] ac. sg. 3", "pr. [8] ac. sg. 3", "fut. md. sg. 1", "opt. [8] md. sg. 3", "ca. pr. md. sg. 3", "aor. [5] ac. sg. 3", "impft. [1] md. sg. 2", "pr. [1] md. du. 3", "cond. ac. pl. 3", "imp. [4] ac. pl. 2", "impft. [1] ac. du. 1", "impft. [4] ac. pl. 3", "pr. [3] ac. pl. 3", "impft. [1] md. sg. 1", "pr. md. . 3", "imp. [1] ac. du. 3", "pr. ac. du. 3", "pr. [9] ac. sg. 2", "pr. [2] ac. du. 1", "aor. [4] ac. du. 2", "pr. [1] ac. sg. 1", "opt. [5] ac. sg. 3", "opt. [10] ac. pl. 2", "imp. [10] ac. du. 3", "aor. [1] ac. sg. 1", "pr. [10] ac. pl. 3", "opt. [2] ac. pl. 3", "opt. [1] ac. du. 3", "pr. [4] md. pl. 3", "impft. ac. sg. 2", "impft . ac. du. 3", "cond. ac. sg. 2", "pr. . sg. 3", "impft. [2] ac. sg. 3", "opt. [6] ac. sg. 3", "pr. [5] ac. sg. 3", "pft. ps. du. 3", "des. imp. md. du. 3", "impft . ac. sg. 3", "impft . ac. sg. ", "pr. md. sg. 2", "impft. [vn.] ac. sg. 3", "pr. [2] ac. sg. 2", "pr. [6] ac. sg. 3", "pr. ac. sg. 2", "pr. [7] ac. pl. 3", "imp. [3] ac. sg. 2", "opt. [2] ac. sg. 1", "pr. [1] md. pl. 2", "pft. md. pl. 3", "pr. [1] ac. sg. 3", "pr. [1] ac. pl. 1", "pr. [3] ac. du. 3", "impft. md. du. 3", "opt. [10] ac. du. 2", "imp. ac. du. 2", "opt. [1] ac. pl. 3", "impft. [vn.] md. sg. 3", "impft . ac. sg. 2", "aor. [4] ac. sg. 3", "ca. fut. ac. sg. 1", "imp. ac. sg. 3", "fut. ac. pl. 2", "impft. md. pl. 3", "ca. fut. ac. sg. 3", "opt. [2] ac. sg. 2", "impft. ac. pl. 3", "pr. [10] ac. sg. 1", "impft. md. sg. 3", "pr. [6] ac. sg. 2", "ca. pr. ps. sg. 3", "impft. ac. sg. 3", "imp. [4] ac. sg. 3", "pr. ps. sg. 1", "pr. [2] md. sg. 3", "imp. [1] ac. du. 2", "impft. [4] ac. sg. 3", "ca. opt. ac. sg. 3", "ca. pr. ac. sg. 3", "opt. ac. pl. 3", "imp. [1] md. sg. 3", "pr. [10] ac. sg. 2", "ca. pr. ps. du. 3", "opt. [5] ac. sg. 1", "impft. ps. pl. 3", "pr. [5] md. sg. 3", "pft. ps. pl. 3", "pft. ac. pl. 2", "ca. opt. ps. sg. 3", "impft. [6] ac. pl. 1", "imp. md. pl. 3", "pr. ps. du. 3", "ca. pr. ac. sg. 1", "opt. md. sg. 3", "fut. md. sg. 3", "impft . ac. pl. 3", "pr. [1] md. sg. 2", "pr. md. sg. 3", "opt. [8] ac. pl. 3", "opt. ac. du. 3", "impft. [2] ac. pl. 3", "imp. [2] ac. sg. 2", "imp. ac. pl. 2", "opt. [10] ac. sg. 3", "pft. md. sg. 1", "pr. [8] ac. sg. 1", "imp. [1] md. sg. 2", "pr. [vn.] md. sg. 3", "ca. pr. ps. sg. 1", "ca. imp. ac. sg. 3", "imp. [2] ac. du. 2", "imp. [6] ac. sg. 3", "opt. [1] ac. sg. 1", "ca. opt. ac. sg. 1", "imp. md. sg. 2", "pft. md. sg. 3", "imp. [2] ac. sg. 3", "impft. [1] ac. sg. 3", "ca. pr. ac. sg. 2", "pr. [1] md. pl. 1", "per. fut. md. sg. 3", "impft. [2] ac. sg. 2", "opt. ac. sg. 2", "imp. [3] ac. sg. 3", "imp. [3] ac. du. 2", "imp. [vn.] ac. pl. 3", "ca. pr. ac. pl. 3", "pr. [1] ac. sg. 2", "imp. [1] ac. pl. 3", "impft. [6] md. pl. 3", "per. fut. ac. sg. 3", "pr. md. pl. 3", "impft. [4] md. pl. 3", "opt. [vn.] ac. sg. 3", "aor. [1] ac. sg. 3", "impft. [4] md. sg. 3", "impft . md. sg. 3", "pr. [1] ac. pl. 3", "pft. ac. pl. 3", "opt. ac. sg. 3", "impft. ps. sg. 3", "impft. [6] ac. sg. 3", "pr. [1] ac. du. 3", "imp. [4] md. sg. 3", "per. fut. ac. sg. 3", "impft. [1] ac. sg. 1", "pr. [2] ac. du. 3", "pr. [2] ac. pl. 1", "impft. [1] ac. du. 3", "fut. ac. du. 2", "imp. [9] ac. sg. 2", "opt. [8] ac. sg. 2", "pr. [9] ac. sg. 3", "des. pr. md. pl. 3", "imp. [8] ac. sg. 2", "inj. [4] md. sg. 3", "fut. ac. pl. 3", "pr. [4] ac. pl. 3", "imp. md. sg. 1", "imp. [6] ac. sg. 2", "imp. ps. sg. 3", "imp. [vn.] ac. sg. 3", "pfp. . pl. ", "impft. [3] ac. pl. 3", "impft. [4] ac. sg. 1", "fut. ac. sg. 1", "per. fut. ac. sg. 2", "pr. [2] ac. pl. 3", "pr. [2] ac. sg. 1", "opt. md. . 3", "des. impft. ac. pl. 2", "opt. [3] ac. pl. 2", "pr. [5] ac. sg. 1", "pr. md. du. 3", "imp. [10] ac. sg. 3", "imp. [2] md. sg. 3", "pft. ac. sg. 3", "impft. [8] ac. sg. 3", "imp. [5] ac. sg. 2", "impft. [8] ac. du. 3", "impft . ps. sg. 3", "ca. opt. ac. sg. 2", "per. fut. ac. pl. 3", "pr. [9] ac. sg. 1", "impft . ps. du. 3", "aor. [1] ac. pl. 3", "pr. [3] ac. sg. 3", "opt. [3] ac. sg. 1", "impft. ac. sg. ", "aor. [1] ac. du. 3", "pr. [1] md. sg. 3", "impft. [1] md. pl. 1", "des. pr. md. sg. 3", "opt. [8] ac. sg. 1", "imp. [4] ac. pl. 3", "imp. [4] ac. du. 3", "pr. ac. sg. 3", "opt. ac. du. 1", "impft. [10] ac. pl. 3", "pr. ac. pl. 3", "pfp. . sg. ", "impft . ac. pl. 2", "fut. ac. sg. 3", "fut. md. pl. 3", "impft. [5] ac. sg. 3", "impft. [1] ac. sg. 2", "pr. [8] ac. pl. 3", "pr. [8] ac. du. 3", "opt. [8] ac. pl. 1", "ca. impft. ac. sg. 3", "opt. ac. sg. 1", "impft. [5] ac. sg. 1", "pp. . du. ", "pr. md. pl. 1", "impft . md. pl. 3", "pr. ps. sg. 3", "des. impft. md. pl. 1", "opt. [10] ac. sg. 2", "impft. [8] md. sg. 3", "pr. [6] md. pl. 3", "imp. [1] md. pl. 2", "opt. [1] md. sg. 3", "opt. [3] md. sg. 3", "ca. fut. md. sg. 3", "pr. [9] ac. pl. 3"], "IV": ["pp.", "inf.", "ca. ppr. ps.", "pp. . . ", "des. ppr. ac.", "ca. pfp. [1]", "ppr. ps.", "ppr.", "abs.", "ppr. [6] ac.", "ppa.", "pfp. [3]", "pfp. [1]", "ppr. [1] ac.", "impft . ac. . ", "ppr. [1] md.", "ca. inf.", "part.", "ca. pfp. [3]", "pfp.", "ppr. [8] ac.", "ca. abs.", "ca. pp.", "ca. ppa.", "ca. ppr. ac.", "ppr. [10] ac.", "ppr. [2] ac.", "pfu. ac.", "tasil", "pfp. [2]", "ca. pfp. [2]", "ppr. [6] ac. ", "pp. [2]"], "Ind": ["prep.", "ind.", "conj."], "adv": ["adv."]}
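The dictionary above maps each coarse class (Noun, FV, IV, Ind, adv) to the fine-grained morphological analyses it covers. A small sketch, assuming the file is loaded from its repository path, of how the reverse lookup (analysis to coarse class) can be built:

import json

with open('utils/io_/coarse_to_ma_dict.json') as f:
    coarse_to_ma = json.load(f)

# Invert coarse -> [analyses] into analysis -> coarse class.
ma_to_coarse = {ma: coarse for coarse, mas in coarse_to_ma.items() for ma in mas}
print(ma_to_coarse['acc. pl. m.'])  # 'Noun'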
utils/io_/convert_ud_to_onto_format.py
ADDED
@@ -0,0 +1,74 @@
import argparse
import os
# /home/jivnesh/Documents/DCST/d
# 'ar','eu','cs','de','hu','lv','pl','sv','fi','ru'
# python utils/io_/convert_ud_to_onto_format.py --ud_data_path d
def write_ud_files(args):
    languages_for_low_resource = ['el']

    languages = sorted(list(set(languages_for_low_resource)))
    splits = ['train', 'dev', 'test']
    lng_to_files = dict((language, {}) for language in languages)
    for language, d in lng_to_files.items():
        for split in splits:
            d[split] = []
        lng_to_files[language] = d
    sub_folders = os.listdir(args.ud_data_path)
    for sub_folder in sub_folders:
        folder = os.path.join(args.ud_data_path, sub_folder)
        files = os.listdir(folder)
        for file in files:
            for language in languages:
                if file.startswith(language) and file.endswith('conllu'):
                    for split in splits:
                        if split in file:
                            full_path = os.path.join(folder, file)
                            lng_to_files[language][split].append(full_path)
                            break

    for language, split_dict in lng_to_files.items():
        for split, files in split_dict.items():
            if split == 'dev' and len(files) == 0:
                files = split_dict['train']
                print('No dev files were found, copying train files instead')
            sentences = []
            num_sentences = 0
            for file in files:
                with open(file, 'r') as f_in:
                    for line in f_in:
                        new_line = []
                        line = line.strip()
                        if len(line) == 0:
                            sentences.append(new_line)
                            num_sentences += 1
                            continue
                        tokens = line.split('\t')
                        if not tokens[0].isdigit():
                            continue
                        id = tokens[0]
                        word = tokens[1]
                        pos = tokens[3]
                        ner = tokens[5]
                        head = tokens[6]
                        arc_tag = tokens[7]
                        new_line = [id, word, pos, ner, head, arc_tag]
                        sentences.append(new_line)
            print('Language: %s Split: %s Num. Sentences: %s ' % (language, split, num_sentences))
            if not os.path.exists('data/MRL'):
                os.makedirs('data/MRL')
            write_data_path = 'data/MRL/ud_pos_ner_dp_' + split + '_' + language
            print('creating %s' % write_data_path)
            with open(write_data_path, 'w') as f:
                for line in sentences:
                    f.write('\t'.join(line) + '\n')

def main():
    # Parse arguments
    args_ = argparse.ArgumentParser()
    args_.add_argument('--ud_data_path', help='Directory path of the UD treebanks.', required=True)

    args = args_.parse_args()
    write_ud_files(args)

if __name__ == "__main__":
    main()
utils/io_/instance.py
ADDED
@@ -0,0 +1,19 @@
class Sentence(object):
    def __init__(self, words, word_ids, char_seqs, char_id_seqs):
        self.words = words
        self.word_ids = word_ids
        self.char_seqs = char_seqs
        self.char_id_seqs = char_id_seqs

    def length(self):
        return len(self.words)

class NER_DependencyInstance(object):
    def __init__(self, sentence, tokens_dict, ids_dict, heads):
        self.sentence = sentence
        self.tokens = tokens_dict
        self.ids = ids_dict
        self.heads = heads

    def length(self):
        return self.sentence.length()
utils/io_/logger.py
ADDED
@@ -0,0 +1,15 @@
import logging
import sys


def get_logger(name, level=logging.INFO, handler=sys.stdout,
               formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter(formatter)
    stream_handler = logging.StreamHandler(handler)
    stream_handler.setLevel(level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    return logger
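A short usage sketch of get_logger; the logger name is arbitrary:

from utils.io_.logger import get_logger

logger = get_logger("Parser")        # streams to sys.stdout at INFO level by default
logger.info("loading alphabets...")  # prints e.g. "<asctime> - Parser - INFO - loading alphabets..."

Note that calling get_logger twice with the same name attaches a second handler and duplicates output, so each component should create its logger once.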
utils/io_/prepare_data.py
ADDED
@@ -0,0 +1,397 @@
import os.path
import numpy as np
from .alphabet import Alphabet
from .logger import get_logger
import torch

# Special vocabulary symbols - we always put them at the beginning.
PAD = "_<PAD>_"
ROOT = "_<ROOT>_"
END = "_<END>_"
_START_VOCAB = [PAD, ROOT, END]

MAX_CHAR_LENGTH = 45
NUM_CHAR_PAD = 2

UNK_ID = 0
PAD_ID_WORD = 1
PAD_ID_CHAR = 1
PAD_ID_TAG = 0

NUM_SYMBOLIC_TAGS = 3

_buckets = [10, 15, 20, 25, 30, 35, 40, 50, 60, 70, 80, 90, 100, 140]

from .reader import Reader

def create_alphabets(alphabet_directory, train_paths, extra_paths=None, max_vocabulary_size=100000, embedd_dict=None,
                     min_occurence=1, lower_case=False):
    def expand_vocab(vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet):
        vocab_set = set(vocab_list)
        for data_path in extra_paths:
            with open(data_path, 'r') as file:
                for line in file:
                    line = line.strip()
                    if len(line) == 0:
                        continue

                    tokens = line.split('\t')
                    if lower_case:
                        tokens[1] = tokens[1].lower()
                    for char in tokens[1]:
                        char_alphabet.add(char)

                    word = tokens[1]
                    pos = tokens[2]
                    ner = tokens[3]
                    arc_tag = tokens[5]

                    pos_alphabet.add(pos)
                    ner_alphabet.add(ner)
                    arc_alphabet.add(arc_tag)
                    if embedd_dict is not None:
                        if word not in vocab_set and (word in embedd_dict or word.lower() in embedd_dict):
                            vocab_set.add(word)
                            vocab_list.append(word)
                    else:
                        if word not in vocab_set:
                            vocab_set.add(word)
                            vocab_list.append(word)
        return vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet

    logger = get_logger("Create Alphabets")
    word_alphabet = Alphabet('word', defualt_value=True, singleton=True)
    char_alphabet = Alphabet('character', defualt_value=True)
    pos_alphabet = Alphabet('pos', defualt_value=True)
    ner_alphabet = Alphabet('ner', defualt_value=True)
    arc_alphabet = Alphabet('arc', defualt_value=True)
    auto_label_alphabet = Alphabet('auto_labeler', defualt_value=True)
    if not os.path.isdir(alphabet_directory):
        logger.info("Creating Alphabets: %s" % alphabet_directory)

        char_alphabet.add(PAD)
        pos_alphabet.add(PAD)
        ner_alphabet.add(PAD)
        arc_alphabet.add(PAD)
        auto_label_alphabet.add(PAD)

        char_alphabet.add(ROOT)
        pos_alphabet.add(ROOT)
        ner_alphabet.add(ROOT)
        arc_alphabet.add(ROOT)
        auto_label_alphabet.add(ROOT)

        char_alphabet.add(END)
        pos_alphabet.add(END)
        ner_alphabet.add(END)
        arc_alphabet.add(END)
        auto_label_alphabet.add(END)

        vocab = dict()
        if isinstance(train_paths, str):
            train_paths = [train_paths]
        for train_path in train_paths:
            with open(train_path, 'r') as file:
                for line in file:
                    line = line.strip()
                    if len(line) == 0:
                        continue

                    tokens = line.split('\t')
                    if lower_case:
                        tokens[1] = tokens[1].lower()
                    for char in tokens[1]:
                        char_alphabet.add(char)

                    word = tokens[1]
                    # print(word)
                    pos = tokens[2]
                    ner = tokens[3]
                    arc_tag = tokens[5]

                    pos_alphabet.add(pos)
                    ner_alphabet.add(ner)
                    arc_alphabet.add(arc_tag)

                    if word in vocab:
                        vocab[word] += 1
                    else:
                        vocab[word] = 1

        # collect singletons
        singletons = set([word for word, count in vocab.items() if count <= min_occurence])

        # if a singleton is in pretrained embedding dict, set the count to min_occur + c
        if embedd_dict is not None:
            for word in vocab.keys():
                if word in embedd_dict or word.lower() in embedd_dict:
                    vocab[word] += min_occurence

        vocab_list = sorted(vocab, key=vocab.get, reverse=True)
        vocab_list = [word for word in vocab_list if vocab[word] > min_occurence]
        vocab_list = _START_VOCAB + vocab_list

        if extra_paths is not None:
            vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet = \
                expand_vocab(vocab_list, char_alphabet, pos_alphabet, ner_alphabet, arc_alphabet)

        if len(vocab_list) > max_vocabulary_size:
            vocab_list = vocab_list[:max_vocabulary_size]

        for word in vocab_list:
            word_alphabet.add(word)
            if word in singletons:
                word_alphabet.add_singleton(word_alphabet.get_index(word))

        word_alphabet.save(alphabet_directory)
        char_alphabet.save(alphabet_directory)
        pos_alphabet.save(alphabet_directory)
        ner_alphabet.save(alphabet_directory)
        arc_alphabet.save(alphabet_directory)
        auto_label_alphabet.save(alphabet_directory)

    else:
        print('loading saved alphabet from %s' % alphabet_directory)
        word_alphabet.load(alphabet_directory)
        char_alphabet.load(alphabet_directory)
        pos_alphabet.load(alphabet_directory)
        ner_alphabet.load(alphabet_directory)
        arc_alphabet.load(alphabet_directory)
        auto_label_alphabet.load(alphabet_directory)

    word_alphabet.close()
    char_alphabet.close()
    pos_alphabet.close()
    ner_alphabet.close()
    arc_alphabet.close()
    auto_label_alphabet.close()

    alphabet_dict = {'word_alphabet': word_alphabet, 'char_alphabet': char_alphabet, 'pos_alphabet': pos_alphabet,
                     'ner_alphabet': ner_alphabet, 'arc_alphabet': arc_alphabet, 'auto_label_alphabet': auto_label_alphabet}
    return alphabet_dict

def create_alphabets_for_sequence_tagger(alphabet_directory, parser_alphabet_directory, paths):
    logger = get_logger("Create Alphabets")
    print('loading saved alphabet from %s' % parser_alphabet_directory)
    word_alphabet = Alphabet('word', defualt_value=True, singleton=True)
    char_alphabet = Alphabet('character', defualt_value=True)
    pos_alphabet = Alphabet('pos', defualt_value=True)
    ner_alphabet = Alphabet('ner', defualt_value=True)
    arc_alphabet = Alphabet('arc', defualt_value=True)
    auto_label_alphabet = Alphabet('auto_labeler', defualt_value=True)

    word_alphabet.load(parser_alphabet_directory)
    char_alphabet.load(parser_alphabet_directory)
    pos_alphabet.load(parser_alphabet_directory)
    ner_alphabet.load(parser_alphabet_directory)
    arc_alphabet.load(parser_alphabet_directory)
    try:
        auto_label_alphabet.load(alphabet_directory)
    except Exception:
        print('Creating auto labeler alphabet')
        auto_label_alphabet.add(PAD)
        auto_label_alphabet.add(ROOT)
        auto_label_alphabet.add(END)
        for path in paths:
            with open(path, 'r') as file:
                for line in file:
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    tokens = line.split('\t')
                    if len(tokens) > 6:
                        auto_label = tokens[6]
                        auto_label_alphabet.add(auto_label)

    word_alphabet.save(alphabet_directory)
    char_alphabet.save(alphabet_directory)
    pos_alphabet.save(alphabet_directory)
    ner_alphabet.save(alphabet_directory)
    arc_alphabet.save(alphabet_directory)
    auto_label_alphabet.save(alphabet_directory)
    word_alphabet.close()
    char_alphabet.close()
    pos_alphabet.close()
    ner_alphabet.close()
    arc_alphabet.close()
    auto_label_alphabet.close()
    alphabet_dict = {'word_alphabet': word_alphabet, 'char_alphabet': char_alphabet, 'pos_alphabet': pos_alphabet,
                     'ner_alphabet': ner_alphabet, 'arc_alphabet': arc_alphabet, 'auto_label_alphabet': auto_label_alphabet}
    return alphabet_dict

def read_data(source_path, alphabets, max_size=None,
              lower_case=False, symbolic_root=False, symbolic_end=False):
    data = [[] for _ in _buckets]
    max_char_length = [0 for _ in _buckets]
    print('Reading data from %s' % (', '.join(source_path) if isinstance(source_path, list) else source_path))
    counter = 0
    if not isinstance(source_path, list):
        source_path = [source_path]
    for path in source_path:
        reader = Reader(path, alphabets)
        inst = reader.getNext(lower_case=lower_case, symbolic_root=symbolic_root, symbolic_end=symbolic_end)
        while inst is not None and (not max_size or counter < max_size):
            counter += 1
            inst_size = inst.length()
            sent = inst.sentence
            for bucket_id, bucket_size in enumerate(_buckets):
                if inst_size < bucket_size:
                    data[bucket_id].append([sent.word_ids, sent.char_id_seqs, inst.ids['pos_alphabet'], inst.ids['ner_alphabet'],
                                            inst.heads, inst.ids['arc_alphabet'], inst.ids['auto_label_alphabet']])
                    max_len = max([len(char_seq) for char_seq in sent.char_seqs])
                    if max_char_length[bucket_id] < max_len:
                        max_char_length[bucket_id] = max_len
                    break

            inst = reader.getNext(lower_case=lower_case, symbolic_root=symbolic_root, symbolic_end=symbolic_end)
        reader.close()
    print("Total number of data: %d" % counter)
    return data, max_char_length

def read_data_to_variable(source_path, alphabets, device, max_size=None,
                          lower_case=False, symbolic_root=False, symbolic_end=False):
    data, max_char_length = read_data(source_path, alphabets,
                                      max_size=max_size, lower_case=lower_case,
                                      symbolic_root=symbolic_root, symbolic_end=symbolic_end)
    bucket_sizes = [len(data[b]) for b in range(len(_buckets))]

    data_variable = []

    for bucket_id in range(len(_buckets)):
        bucket_size = bucket_sizes[bucket_id]
        if bucket_size <= 0:
            data_variable.append((1, 1))
            continue

        bucket_length = _buckets[bucket_id]
        char_length = min(MAX_CHAR_LENGTH, max_char_length[bucket_id] + NUM_CHAR_PAD)
        wid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
        cid_inputs = np.empty([bucket_size, bucket_length, char_length], dtype=np.int64)
        pid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
        nid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
        hid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
        aid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)
        mid_inputs = np.empty([bucket_size, bucket_length], dtype=np.int64)

        masks = np.zeros([bucket_size, bucket_length], dtype=np.float32)
        single = np.zeros([bucket_size, bucket_length], dtype=np.int64)

        lengths = np.empty(bucket_size, dtype=np.int64)

        for i, inst in enumerate(data[bucket_id]):
            wids, cid_seqs, pids, nids, hids, aids, mids = inst
            inst_size = len(wids)
            lengths[i] = inst_size
            # word ids
            wid_inputs[i, :inst_size] = wids
            wid_inputs[i, inst_size:] = PAD_ID_WORD
            for c, cids in enumerate(cid_seqs):
                cid_inputs[i, c, :len(cids)] = cids
                cid_inputs[i, c, len(cids):] = PAD_ID_CHAR
            cid_inputs[i, inst_size:, :] = PAD_ID_CHAR
            # pos ids
            pid_inputs[i, :inst_size] = pids
            pid_inputs[i, inst_size:] = PAD_ID_TAG
            # ner ids
            nid_inputs[i, :inst_size] = nids
            nid_inputs[i, inst_size:] = PAD_ID_TAG
            # arc ids
            aid_inputs[i, :inst_size] = aids
            aid_inputs[i, inst_size:] = PAD_ID_TAG
            # auto_label ids
            mid_inputs[i, :inst_size] = mids
            mid_inputs[i, inst_size:] = PAD_ID_TAG
            # heads
            hid_inputs[i, :inst_size] = hids
            hid_inputs[i, inst_size:] = PAD_ID_TAG
            # masks
            masks[i, :inst_size] = 1.0
            for j, wid in enumerate(wids):
                if alphabets['word_alphabet'].is_singleton(wid):
                    single[i, j] = 1

        words = torch.LongTensor(wid_inputs)
        chars = torch.LongTensor(cid_inputs)
        pos = torch.LongTensor(pid_inputs)
        ner = torch.LongTensor(nid_inputs)
        heads = torch.LongTensor(hid_inputs)
        arc = torch.LongTensor(aid_inputs)
        auto_label = torch.LongTensor(mid_inputs)
        masks = torch.FloatTensor(masks)
        single = torch.LongTensor(single)
        lengths = torch.LongTensor(lengths)
        words = words.to(device)
        chars = chars.to(device)
        pos = pos.to(device)
        ner = ner.to(device)
        heads = heads.to(device)
        arc = arc.to(device)
        auto_label = auto_label.to(device)
        masks = masks.to(device)
        single = single.to(device)
        lengths = lengths.to(device)

        data_variable.append((words, chars, pos, ner, heads, arc, auto_label, masks, single, lengths))

    return data_variable, bucket_sizes

def iterate_batch(data, batch_size, device, unk_replace=0.0, shuffle=False):
    data_variable, bucket_sizes = data

    bucket_indices = np.arange(len(_buckets))
    if shuffle:
        np.random.shuffle(bucket_indices)

    for bucket_id in bucket_indices:
        bucket_size = bucket_sizes[bucket_id]
        bucket_length = _buckets[bucket_id]
        if bucket_size <= 0:
            continue

        words, chars, pos, ner, heads, arc, auto_label, masks, single, lengths = data_variable[bucket_id]
        if unk_replace:
            ones = single.data.new(bucket_size, bucket_length).fill_(1)
            noise = masks.data.new(bucket_size, bucket_length).bernoulli_(unk_replace).long()
            words = words * (ones - single * noise)

        indices = None
        if shuffle:
            indices = torch.randperm(bucket_size).long()
            indices = indices.to(device)
        for start_idx in range(0, bucket_size, batch_size):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batch_size]
            else:
                excerpt = slice(start_idx, start_idx + batch_size)
            yield words[excerpt], chars[excerpt], pos[excerpt], ner[excerpt], heads[excerpt], arc[excerpt], auto_label[excerpt], \
                  masks[excerpt], lengths[excerpt]

def iterate_batch_rand_bucket_choosing(data, batch_size, device, unk_replace=0.0):
    data_variable, bucket_sizes = data
    indices_left = [set(np.arange(bucket_size)) for bucket_size in bucket_sizes]
    while sum(bucket_sizes) > 0:
        non_empty_buckets = [i for i, bucket_size in enumerate(bucket_sizes) if bucket_size > 0]
        bucket_id = np.random.choice(non_empty_buckets)
        bucket_size = bucket_sizes[bucket_id]
        bucket_length = _buckets[bucket_id]

        words, chars, pos, ner, heads, arc, auto_label, masks, single, lengths = data_variable[bucket_id]
        min_batch_size = min(bucket_size, batch_size)
        indices = torch.LongTensor(np.random.choice(list(indices_left[bucket_id]), min_batch_size, replace=False))
        set_indices = set(indices.numpy())
        indices_left[bucket_id] = indices_left[bucket_id].difference(set_indices)
        indices = indices.to(device)
        words = words[indices]
        if unk_replace:
            ones = single.data.new(min_batch_size, bucket_length).fill_(1)
            noise = masks.data.new(min_batch_size, bucket_length).bernoulli_(unk_replace).long()
            words = words * (ones - single[indices] * noise)
        bucket_sizes = [len(s) for s in indices_left]
        yield words, chars[indices], pos[indices], ner[indices], heads[indices], arc[indices], auto_label[indices], masks[indices], lengths[indices]


def calc_num_batches(data, batch_size):
    _, bucket_sizes = data
    bucket_sizes_mod_batch_size = [int(bucket_size / batch_size) + 1 if bucket_size > 0 else 0 for bucket_size in bucket_sizes]
    num_batches = sum(bucket_sizes_mod_batch_size)
    return num_batches
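A minimal end-to-end sketch of the data pipeline defined above, assuming the bundled Sanskrit split paths under data/ and an arbitrary batch size of 16:

import torch
from utils.io_ import prepare_data

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
alphabets = prepare_data.create_alphabets('saved_models/alphabets/',
                                          'data/ud_pos_ner_dp_train_san')
data = prepare_data.read_data_to_variable('data/ud_pos_ner_dp_train_san', alphabets, device,
                                          symbolic_root=True)
print('batches per epoch:', prepare_data.calc_num_batches(data, 16))
for words, chars, pos, ner, heads, arc, auto_label, masks, lengths in \
        prepare_data.iterate_batch(data, 16, device, shuffle=True):
    # every tensor is [batch, bucket_length]; chars is [batch, bucket_length, char_length]
    break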
utils/io_/reader.py
ADDED
@@ -0,0 +1,93 @@
from .instance import NER_DependencyInstance
from .instance import Sentence
from .prepare_data import ROOT, END, MAX_CHAR_LENGTH

class Reader(object):
    def __init__(self, file_path, alphabets):
        self.__source_file = open(file_path, 'r')
        self.alphabets = alphabets

    def close(self):
        self.__source_file.close()

    def getNext(self, lower_case=False, symbolic_root=False, symbolic_end=False):
        line = self.__source_file.readline()
        # skip multiple blank lines.
        while len(line) > 0 and len(line.strip()) == 0:
            line = self.__source_file.readline()
        if len(line) == 0:
            return None

        lines = []
        while len(line.strip()) > 0:
            line = line.strip()
            lines.append(line.split('\t'))
            line = self.__source_file.readline()

        length = len(lines)
        if length == 0:
            return None

        heads = []
        tokens_dict = {}
        ids_dict = {}
        for alphabet_name in self.alphabets.keys():
            tokens_dict[alphabet_name] = []
            ids_dict[alphabet_name] = []
        if symbolic_root:
            for alphabet_name, alphabet in self.alphabets.items():
                if alphabet_name.startswith('char'):
                    tokens_dict[alphabet_name].append([ROOT, ])
                    ids_dict[alphabet_name].append([alphabet.get_index(ROOT), ])
                else:
                    tokens_dict[alphabet_name].append(ROOT)
                    ids_dict[alphabet_name].append(alphabet.get_index(ROOT))
            heads.append(0)

        for tokens in lines:
            chars = []
            char_ids = []
            if lower_case:
                tokens[1] = tokens[1].lower()
            for char in tokens[1]:
                chars.append(char)
                char_ids.append(self.alphabets['char_alphabet'].get_index(char))
            if len(chars) > MAX_CHAR_LENGTH:
                chars = chars[:MAX_CHAR_LENGTH]
                char_ids = char_ids[:MAX_CHAR_LENGTH]
            tokens_dict['char_alphabet'].append(chars)
            ids_dict['char_alphabet'].append(char_ids)

            word = tokens[1]
            # print(word+ ' ')
            pos = tokens[2]
            ner = tokens[3]
            head = int(tokens[4])
            arc_tag = tokens[5]
            if len(tokens) > 6:
                auto_label = tokens[6]
                tokens_dict['auto_label_alphabet'].append(auto_label)
                ids_dict['auto_label_alphabet'].append(self.alphabets['auto_label_alphabet'].get_index(auto_label))
            tokens_dict['word_alphabet'].append(word)
            ids_dict['word_alphabet'].append(self.alphabets['word_alphabet'].get_index(word))
            tokens_dict['pos_alphabet'].append(pos)
            ids_dict['pos_alphabet'].append(self.alphabets['pos_alphabet'].get_index(pos))
            tokens_dict['ner_alphabet'].append(ner)
            ids_dict['ner_alphabet'].append(self.alphabets['ner_alphabet'].get_index(ner))
            tokens_dict['arc_alphabet'].append(arc_tag)
            ids_dict['arc_alphabet'].append(self.alphabets['arc_alphabet'].get_index(arc_tag))
            heads.append(head)

        if symbolic_end:
            for alphabet_name, alphabet in self.alphabets.items():
                if alphabet_name.startswith('char'):
                    tokens_dict[alphabet_name].append([END, ])
                    ids_dict[alphabet_name].append([alphabet.get_index(END), ])
                else:
                    tokens_dict[alphabet_name].append(END)
                    ids_dict[alphabet_name].append(alphabet.get_index(END))
            heads.append(0)

        return NER_DependencyInstance(Sentence(tokens_dict['word_alphabet'], ids_dict['word_alphabet'],
                                               tokens_dict['char_alphabet'], ids_dict['char_alphabet']),
                                      tokens_dict, ids_dict, heads)
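A small sketch of driving Reader directly (normally read_data in prepare_data.py does this); it assumes alphabets were already created from the bundled Sanskrit training split:

from utils.io_.prepare_data import create_alphabets
from utils.io_.reader import Reader

# Each non-blank row is: index, word, POS, NER, head index, arc label (tab-separated);
# blank lines separate sentences, matching the ud_pos_ner_dp_* files in data/.
alphabets = create_alphabets('saved_models/alphabets/', 'data/ud_pos_ner_dp_train_san')
reader = Reader('data/ud_pos_ner_dp_dev_san', alphabets)
inst = reader.getNext(symbolic_root=True)
while inst is not None:
    print(inst.length(), inst.tokens['word_alphabet'][:5])
    inst = reader.getNext(symbolic_root=True)
reader.close()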
utils/io_/rearrange_splits.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
from torch import stack

def rearranging_splits(datasets, num_training_samples):
    new_datasets = {}
    data_splits = datasets.keys()
    for split in data_splits:
        if split == 'test':
            new_datasets['test'] = datasets['test']

        else:
            num_buckets = len(datasets[split][1])
            num_tensors = len(datasets[split][0][0])
            num_samples = sum(datasets[split][1])
            if num_samples < num_training_samples:
                print("set_num_training_samples (%d) should be smaller than the actual %s size (%d)"
                      % (num_training_samples, split, num_samples))
            new_datasets[split] = [[[[] for _ in range(num_tensors)] for _ in range(num_buckets)], []]
            new_datasets['extra_' + split] = [[[[] for _ in range(num_tensors)] for _ in range(num_buckets)], []]
    for split in data_splits:
        if split == 'test':
            continue
        else:
            curr_bucket_sizes = datasets[split][1]
            curr_samples = datasets[split][0]
            num_tensors = len(datasets[split][0][0])
            curr_num_samples = sum(curr_bucket_sizes)
            sample_indices_in_buckets = {}
            i = 0
            for bucket_idx, bucket_size in enumerate(curr_bucket_sizes):
                for sample_idx in range(bucket_size):
                    sample_indices_in_buckets[i] = (bucket_idx, sample_idx)
                    i += 1
            rng = np.arange(curr_num_samples)
            rng = np.random.permutation(rng)
            sample_indices = {}
            sample_indices[split] = [sample_indices_in_buckets[key] for key in rng[:num_training_samples]]
            sample_indices['extra_' + split] = [sample_indices_in_buckets[key] for key in rng[num_training_samples:]]
            if len(sample_indices['extra_' + split]) == 0:
                if len(sample_indices[split]) > 1:
                    sample_indices['extra_' + split].append(sample_indices[split].pop(-1))
                else:
                    sample_indices['extra_' + split].append(sample_indices[split][0])

            for key, indices in sample_indices.items():
                for bucket_idx, sample_idx in indices:
                    curr_bucket = curr_samples[bucket_idx]
                    for tensor_idx, tensor in enumerate(curr_bucket):
                        new_datasets[key][0][bucket_idx][tensor_idx].append(tensor[sample_idx])
    del datasets
    new_splits = []
    new_splits += [split for split in data_splits if split != 'test']
    new_splits += ['extra_' + split for split in data_splits if split != 'test']

    for split in new_splits:
        for bucket_idx in range(num_buckets):
            for tensor_idx in range(num_tensors):
                if len(new_datasets[split][0][bucket_idx][tensor_idx]) > 0:
                    new_datasets[split][0][bucket_idx][tensor_idx] = stack(new_datasets[split][0][bucket_idx][tensor_idx])
                else:
                    new_datasets[split][0][bucket_idx] = (1, 1)
                    break
            # set lengths of buckets
            if new_datasets[split][0][bucket_idx] == (1, 1):
                new_datasets[split][1].append(0)
            else:
                new_datasets[split][1].append(len(new_datasets[split][0][bucket_idx][tensor_idx]))
    return new_datasets
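A minimal usage sketch of rearranging_splits, assuming the repository root is on PYTHONPATH. The toy dataset below is hypothetical: one bucket with five samples and two tensor fields, in the structure the function indexes above (per split: [list of buckets, list of bucket sizes]); the test split passes through unchanged while train is split into 'train' and 'extra_train':

    import torch
    from utils.io_.rearrange_splits import rearranging_splits

    bucket = [torch.arange(5), torch.arange(5) * 10]      # two tensor fields, five samples
    datasets = {'train': [[bucket], [5]], 'test': [[bucket], [5]]}
    new_datasets = rearranging_splits(datasets, num_training_samples=3)
    print(new_datasets['train'][1], new_datasets['extra_train'][1])  # e.g. [3] [2]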
utils/io_/remove_xx.py
ADDED
@@ -0,0 +1,60 @@
import numpy as np

def read_file(filename):
    sentences = []
    sentence = []
    lengths = []
    num_sentences_to_remove = 0
    num_sentences = 0
    num_tokens_to_remove = 0
    num_tokens = 0
    with open(filename, 'r') as file:
        for line in file:
            line = line.strip()
            if len(line) == 0:
                xx_count = 0
                for row in sentence:
                    if row[2] == 'XX':
                        xx_count += 1
                if xx_count / len(sentence) >= 0.5:
                    num_sentences_to_remove += 1
                    num_tokens_to_remove += len(sentence)
                else:
                    sentences.append(sentence)
                    lengths.append(len(sentence))
                num_sentences += 1
                num_tokens += len(sentence)
                sentence = []
                continue
            tokens = line.split('\t')
            idx = tokens[0]
            word = tokens[1]
            pos = tokens[2]
            ner = tokens[3]
            arc = tokens[4]
            arc_tag = tokens[5]
            sentence.append((idx, word, pos, ner, arc, arc_tag))
    print("removed %d sentences out of %d sentences" % (num_sentences_to_remove, num_sentences))
    print("removed %d tokens out of %d tokens" % (num_tokens_to_remove, num_tokens))
    return sentences

def write_file(filename, sentences):
    with open(filename, 'w') as file:
        for sentence in sentences:
            for row in sentence:
                file.write('\t'.join([token for token in row]) + '\n')
            file.write('\n')

dataset_dict = {'ontonotes': 'onto'}
datasets = ['ontonotes']
splits = ['test']
domains = ['all', 'wb']

for dataset in datasets:
    for domain in domains:
        for split in splits:
            print('dataset: %s, domain: %s, split: %s' % (dataset, domain, split))
            filename = 'data/' + dataset_dict[dataset] + '_pos_ner_dp_' + split + '_' + domain
            sentences = read_file(filename)
            write_filename = filename + '_without_xx'
            write_file(write_filename, sentences)
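A quick check of the filtering rule above on a hypothetical two-token sentence: any sentence in which at least half of the tokens carry the 'XX' placeholder POS tag is dropped:

    sentence = [('1', 'foo', 'XX', 'O', '2', 'dep'), ('2', 'bar', 'NN', 'O', '0', 'root')]
    xx_ratio = sum(row[2] == 'XX' for row in sentence) / len(sentence)
    print(xx_ratio >= 0.5)  # True -> this sentence would be removed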
utils/io_/seeds.py
ADDED
@@ -0,0 +1,12 @@
import torch
import numpy as np
import random

seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
np.random.seed(seed)  # Numpy module.
random.seed(seed)  # Python random module.
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
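Importing this module is enough to apply all of the seeds, since they are set at module level as a side effect. A minimal sketch, assuming the repository root is on PYTHONPATH:

    import utils.io_.seeds  # noqa: F401  (imported only for its seeding side effect)
    import torch
    print(torch.rand(1))    # same value on every run, since the global seed is fixed to 0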
utils/io_/write_extra_labels.py
ADDED
@@ -0,0 +1,1592 @@
import os
from .prepare_data import ROOT, END
import pdb

def get_split(path):
    if 'train' in path:
        if 'extra_train' in path:
            split = 'extra_train'
        else:
            split = 'train'
    elif 'dev' in path:
        if 'extra_dev' in path:
            split = 'extra_dev'
        else:
            split = 'dev'
    else:
        split = 'test'
    return split

def add_number_of_children(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and 'poetry' not in file and 'prose' not in file and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]

        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        number_of_children = {}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:
            for line in file:
                # line = line.decode('utf-8')
                line = line.strip()
                # print(line)
                if len(line) == 0:
                    for idx in range(len_sent):
                        node = str(idx + 1)
                        if node not in number_of_children:
                            lines[idx].append('0')
                        else:
                            lines[idx].append(str(number_of_children[node]))
                    if len(lines) > 0:
                        tmp_root_line = root_line + [str(number_of_children['0'])]
                        sentences_list.append(tmp_root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    number_of_children = {}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                if head not in number_of_children:
                    number_of_children[head] = 1
                else:
                    number_of_children[head] += 1
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

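# Illustration (added note; the rows are hypothetical): for a two-token tree with
# arcs 1 -> 2 and 2 -> ROOT, add_number_of_children appends each token's child
# count as the extra column:
#   1  w1  NN  O  2  nsubj  0     (token 1 has no children)
#   2  w2  VB  O  0  root   1     (token 2 governs token 1)
# and the inserted ROOT line receives the count of tokens attached to ROOT ('1').
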
def add_distance_from_the_root(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'poetry' not in file and 'prose' not in file and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and 'poetry' not in file and 'prose' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]

        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '0']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        tree_dict = {'0': '0'}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:
            for line in file:
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        depth = 1
                        node = str(idx + 1)
                        while tree_dict[node] != '0':
                            node = tree_dict[node]
                            depth += 1
                        lines[idx].append(str(depth))
                    if len(lines) > 0:
                        sentences_list.append(root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    tree_dict = {'0': '0'}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                tree_dict[idx] = head
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

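# Illustration (added note; hypothetical tree): with arcs 1 -> 2 and 2 -> ROOT,
# add_distance_from_the_root appends each token's tree depth as the extra
# column: token 2 (attached to ROOT) gets '1', token 1 gets '2', and the
# inserted ROOT line carries the fixed '0' from root_line above.
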
def add_relative_pos_based(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    # most of the code for this function is taken from:
    # https://github.com/mstrise/dep2label/blob/master/encoding.py
    def pos_cluster(pos):
        # clustering the parts of speech
        if pos[0] == 'V':
            pos = 'VB'
        elif pos == 'NNS':
            pos = 'NN'
        elif pos == 'NNPS':
            pos = 'NNP'
        elif 'JJ' in pos:
            pos = 'JJ'
        elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
            pos = 'RB'
        elif pos[:3] == 'PRP':
            pos = 'PRP'
        elif pos in ['.', ':', ',', "''", '``']:
            pos = '.'
        elif pos[0] == '-':
            pos = '-RB-'
        elif pos[:2] == 'WP':
            pos = 'WP'
        return pos

    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'poetry' not in file and 'prose' not in file and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'poetry' not in file and 'prose' not in file and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and 'poetry' not in file and 'prose' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '+0_XX']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        tree_dict = {'0': '0'}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:
            for line in file:
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        info_of_a_word = lines[idx]
                        # head is on the right side from the word
                        head = int(info_of_a_word[4]) - 1
                        if head == -1:
                            info_about_head = root_line
                        else:
                            info_about_head = lines[head]
                        if idx < head:
                            relative_position_head = 1
                            postag_head = pos_cluster(info_about_head[2])

                            for x in range(idx + 1, head):
                                another_word = lines[x]
                                postag_word_before_head = pos_cluster(another_word[2])
                                if postag_word_before_head == postag_head:
                                    relative_position_head += 1
                            label = str(
                                "+" +
                                repr(relative_position_head) +
                                "_" +
                                postag_head)
                            lines[idx].append(label)

                        # head is on the left side from the word
                        elif idx > head:
                            relative_position_head = 1
                            postag_head = pos_cluster(info_about_head[2])
                            for x in range(head + 1, idx):
                                another_word = lines[x]
                                postag_word_before_head = pos_cluster(another_word[2])
                                if postag_word_before_head == postag_head:
                                    relative_position_head += 1
                            label = str(
                                "-" +
                                repr(relative_position_head) +
                                "_" +
                                postag_head)
                            lines[idx].append(label)
                    if len(lines) > 0:
                        sentences_list.append(root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    tree_dict = {'0': '0'}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                tree_dict[idx] = head
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

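# Illustration (added note): in the dep2label-style encoding above, a label such
# as '+2_VB' reads "the head is the 2nd token with clustered POS 'VB' to the
# right of this word", and '-1_NN' points at the 1st 'NN' to the left. A token
# whose index equals its head index (impossible in a well-formed tree) would
# receive no label.
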
def add_language_model(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]

        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:
            for line in file:
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        if idx < len_sent - 1:
                            lines[idx].append(lines[idx + 1][1])
                        else:
                            lines[idx].append(END)
                    if len(lines) > 0:
                        tmp_root_line = root_line + [lines[0][1]]
                        sentences_list.append(tmp_root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

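# Illustration (added note): add_language_model appends the following word as
# each token's auxiliary label, i.e. the tagger learns to predict the right
# neighbour; the last token gets the END symbol and the inserted ROOT line
# gets the sentence's first word.
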
def add_relative_TAG(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    # most of the code for this function is taken from:
    # https://github.com/mstrise/dep2label/blob/master/encoding.py
    def pos_cluster(pos):
        # clustering the parts of speech
        if pos[0] == 'V':
            pos = 'VB'
        elif pos == 'NNS':
            pos = 'NN'
        elif pos == 'NNPS':
            pos = 'NNP'
        elif 'JJ' in pos:
            pos = 'JJ'
        elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
            pos = 'RB'
        elif pos[:3] == 'PRP':
            pos = 'PRP'
        elif pos in ['.', ':', ',', "''", '``']:
            pos = '.'
        elif pos[0] == '-':
            pos = '-RB-'
        elif pos[:2] == 'WP':
            pos = 'WP'
        return pos

    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '+0_XX']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        tree_dict = {'0': '0'}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:

            for line in file:
                # print(line)
                # print(reading_path)
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        info_of_a_word = lines[idx]
                        # head is on the right side from the word
                        head = int(info_of_a_word[4]) - 1
                        if head == -1:
                            info_about_head = root_line
                        else:
                            # print(len(lines), head)
                            info_about_head = lines[head]

                        if idx < head:
                            relative_position_head = 1
                            tag_head = info_about_head[5]

                            for x in range(idx + 1, head):
                                another_word = lines[x]
                                tag_word_before_head = another_word[5]
                                if tag_word_before_head == tag_head:
                                    relative_position_head += 1
                            label = str(
                                "+" +
                                repr(relative_position_head) +
                                "_" +
                                tag_head)
                            lines[idx].append(label)

                        # head is on the left side from the word
                        elif idx > head:
                            relative_position_head = 1
                            tag_head = info_about_head[5]
                            for x in range(head + 1, idx):
                                another_word = lines[x]
                                tag_word_before_head = another_word[5]
                                if tag_word_before_head == tag_head:
                                    relative_position_head += 1
                            label = str(
                                "-" +
                                repr(relative_position_head) +
                                "_" +
                                tag_head)
                            lines[idx].append(label)
                    if len(lines) > 0:
                        sentences_list.append(root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    tree_dict = {'0': '0'}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                tree_dict[idx] = head
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

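# Illustration (added note): add_relative_TAG uses the same relative-position
# encoding as add_relative_pos_based, but counts occurrences of the head's
# dependency label instead of its clustered POS, so '+1_nsubj' means "the head
# is the next token to the right whose own arc label is 'nsubj'".
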
def add_head(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    # most of the code for this function is taken from:
    # https://github.com/mstrise/dep2label/blob/master/encoding.py
    def pos_cluster(pos):
        # clustering the parts of speech
        if pos[0] == 'V':
            pos = 'VB'
        elif pos == 'NNS':
            pos = 'NN'
        elif pos == 'NNPS':
            pos = 'NNP'
        elif 'JJ' in pos:
            pos = 'JJ'
        elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
            pos = 'RB'
        elif pos[:3] == 'PRP':
            pos = 'PRP'
        elif pos in ['.', ':', ',', "''", '``']:
            pos = '.'
        elif pos[0] == '-':
            pos = '-RB-'
        elif pos[:2] == 'WP':
            pos = 'WP'
        return pos

    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root', '+0_XX']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        tree_dict = {'0': '0'}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:

            for line in file:
                # print(line)
                # print(reading_path)
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        info_of_a_word = lines[idx]
                        # head is on the right side from the word
                        head = int(info_of_a_word[4]) - 1
                        if head == -1:
                            info_about_head = root_line
                        else:
                            # print(len(lines), head)
                            info_about_head = lines[head]
                        head_word = info_about_head[1]
                        lines[idx].append(head_word)
                        # if idx < head:
                        #     relative_position_head = 1

                        #     for x in range(idx + 1, head):
                        #         another_word = lines[x]
                        #         postag_word_before_head = pos_cluster(another_word[2])
                        #         if postag_word_before_head == postag_head:
                        #             relative_position_head += 1
                        #     label = str(
                        #         "+" +
                        #         repr(relative_position_head) +
                        #         "_" +
                        #         postag_head)

                        # # head is on the left side from the word
                        # elif idx > head:
                        #     relative_position_head = 1
                        #     postag_head = pos_cluster(info_about_head[2])
                        #     for x in range(head + 1, idx):
                        #         another_word = lines[x]
                        #         postag_word_before_head = pos_cluster(another_word[2])
                        #         if postag_word_before_head == postag_head:
                        #             relative_position_head += 1
                        #     label = str(
                        #         "-" +
                        #         repr(relative_position_head) +
                        #         "_" +
                        #         postag_head)
                        # lines[idx].append(label)
                    if len(lines) > 0:
                        sentences_list.append(root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    tree_dict = {'0': '0'}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                tree_dict[idx] = head
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

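# Illustration (added note): add_head appends the surface form of each token's
# head word as the auxiliary label (tokens attached to ROOT get the ROOT
# symbol); the commented-out block above is the unused relative-position
# variant kept from the dep2label encoding.
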
import json

def get_modified_coarse(ma):
    ma = ma.replace('sgpl', 'sg').replace('sgdu', 'sg')
    with open('/home/jivnesh/DCST_scratch/utils/io_/coarse_to_ma_dict.json', 'r') as fh:
        coarse_dict = json.load(fh)
    for key in coarse_dict.keys():
        if ma in coarse_dict[key]:
            return key

def add_head_coarse_pos(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    # most of the code for this function is taken from:
    # https://github.com/mstrise/dep2label/blob/master/encoding.py
    def pos_cluster(pos):
        # clustering the parts of speech
        if pos[0] == 'V':
            pos = 'VB'
        elif pos == 'NNS':
            pos = 'NN'
        elif pos == 'NNPS':
            pos = 'NNP'
        elif 'JJ' in pos:
            pos = 'JJ'
        elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
            pos = 'RB'
        elif pos[:3] == 'PRP':
            pos = 'PRP'
        elif pos in ['.', ':', ',', "''", '``']:
            pos = '.'
        elif pos[0] == '-':
            pos = '-RB-'
        elif pos[:2] == 'WP':
            pos = 'WP'
        return pos

    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root', 'O']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        tree_dict = {'0': '0'}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:

            for line in file:
                # print(line)
                # print(reading_path)
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        info_of_a_word = lines[idx]
                        # head is on the right side from the word
                        head = int(info_of_a_word[4]) - 1
                        if head == -1:
                            info_about_head = root_line
                        else:
                            # print(len(lines), head)
                            info_about_head = lines[head]
                        postag_head = info_about_head[2]
                        lines[idx].append(postag_head)
                    if len(lines) > 0:
                        sentences_list.append(root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    tree_dict = {'0': '0'}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                tree_dict[idx] = head
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

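# Illustration (added note): add_head_coarse_pos appends the head's POS tag
# verbatim (unclustered) as the auxiliary label, e.g. a token whose head is
# tagged 'NN' gets the label 'NN'; tokens attached to ROOT get the root line's
# 'XX'. add_head_ma below is the same except the head's POS is first passed
# through pos_cluster.
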
def add_head_ma(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    # most of the code for this function is taken from:
    # https://github.com/mstrise/dep2label/blob/master/encoding.py
    def pos_cluster(pos):
        # clustering the parts of speech
        if pos[0] == 'V':
            pos = 'VB'
        elif pos == 'NNS':
            pos = 'NN'
        elif pos == 'NNPS':
            pos = 'NNP'
        elif 'JJ' in pos:
            pos = 'JJ'
        elif pos[:2] == 'RB' or pos == 'WRB' or pos == 'RP':
            pos = 'RB'
        elif pos[:3] == 'PRP':
            pos = 'PRP'
        elif pos in ['.', ':', ',', "''", '``']:
            pos = '.'
        elif pos[0] == '-':
            pos = '-RB-'
        elif pos[:2] == 'WP':
            pos = 'WP'
        return pos

    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root', 'XX']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        tree_dict = {'0': '0'}
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:

            for line in file:
                # print(line)
                # print(reading_path)
                # line = line.decode('utf-8')
                line = line.strip()
                if len(line) == 0:
                    for idx in range(len_sent):
                        info_of_a_word = lines[idx]
                        # head is on the right side from the word
                        head = int(info_of_a_word[4]) - 1
                        if head == -1:
                            info_about_head = root_line
                        else:
                            # print(len(lines), head)
                            info_about_head = lines[head]
                        postag_head = pos_cluster(info_about_head[2])
                        lines[idx].append(postag_head)
                    if len(lines) > 0:
                        sentences_list.append(root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    tree_dict = {'0': '0'}
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                tree_dict[idx] = head
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

def add_label(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
    if src_domain == tgt_domain:
        pred_paths = []
        if use_unlabeled_data:
            pred_paths = [file for file in os.listdir(parser_path) if
                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]

        gold_paths = [file for file in os.listdir(parser_path) if
                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
        if use_labeled_data:
            gold_paths += [file for file in os.listdir(parser_path) if
                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError
    else:
        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]

        gold_paths = []
        if use_labeled_data:
            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]

        if not use_unlabeled_data and not use_labeled_data:
            raise ValueError

    paths = pred_paths + gold_paths
    print('############ Add Label Task #################')
    print("Adding labels to paths: %s" % ', '.join(paths))
    root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
    writing_paths = {}
    sentences = {}
    for path in paths:
        if tgt_domain in path:
            reading_path = parser_path + path
            writing_path = model_path + 'parser_' + path
            split = get_split(writing_path)
        else:
            reading_path = path
            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
            split = 'extra_train'
        writing_paths[split] = writing_path
        len_sent = 0
        lines = []
        sentences_list = []
        with open(reading_path, 'r') as file:
            for line in file:
                # line = line.decode('utf-8')
                line = line.strip()
                # A blank line marks the end of a sentence
                if len(line) == 0:
                    # Append each token's own arc label as the extra column
                    for idx in range(len_sent):
                        lines[idx].append(lines[idx][5])
                    # Add root line first
                    if len(lines) > 0:
                        tmp_root_line = root_line + [root_line[5]]
                        sentences_list.append(tmp_root_line)
                        for line_ in lines:
                            sentences_list.append(line_)
                        sentences_list.append([])
                    lines = []
                    len_sent = 0
                    continue
                tokens = line.split('\t')
                idx = tokens[0]
                word = tokens[1]
                pos = tokens[2]
                ner = tokens[3]
                head = tokens[4]
                arc_tag = tokens[5]
                lines.append([idx, word, pos, ner, head, arc_tag])
                len_sent += 1
        sentences[split] = sentences_list

    train_sentences = []
    if 'train' in sentences:
        train_sentences = sentences['train']
    else:
        # pdb.set_trace()
        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
    if 'extra_train' in sentences:
        train_sentences += sentences['extra_train']
        del writing_paths['extra_train']
    if 'extra_dev' in sentences:
        train_sentences += sentences['extra_dev']
        del writing_paths['extra_dev']
    with open(writing_paths['train'], 'w') as f:
        for sent in train_sentences:
            f.write('\t'.join(sent) + '\n')
    for split in ['dev', 'test']:
        if split in sentences:
            split_sentences = sentences[split]
            with open(writing_paths[split], 'w') as f:
                for sent in split_sentences:
                    f.write('\t'.join(sent) + '\n')
    return writing_paths

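# Illustration (added note): add_label copies each token's own dependency label
# into the auxiliary-label column, so the sequence tagger is trained to predict
# arc labels directly; the inserted ROOT line duplicates its 'root' label.
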
1150 |
+
+def predict_ma_tag_of_modifier(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Auto-label: the normalized morphological analysis (MA) tag of each token
+    # (column 3, cleaned by clean_ma defined further below).
+    if src_domain == tgt_domain:
+        pred_paths = []
+        if use_unlabeled_data:
+            pred_paths = [file for file in os.listdir(parser_path) if
+                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
+
+        gold_paths = [file for file in os.listdir(parser_path) if
+                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
+        if use_labeled_data:
+            gold_paths += [file for file in os.listdir(parser_path) if
+                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
+
+        if not use_unlabeled_data and not use_labeled_data:
+            raise ValueError("at least one of use_unlabeled_data / use_labeled_data must be True")
+    else:
+        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
+
+        gold_paths = []
+        if use_labeled_data:
+            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
+
+        if not use_unlabeled_data and not use_labeled_data:
+            raise ValueError("at least one of use_unlabeled_data / use_labeled_data must be True")
+
+    paths = pred_paths + gold_paths
+    print('############ Add Label Task #################')
+    print("Adding labels to paths: %s" % ', '.join(paths))
+    root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
+    writing_paths = {}
+    sentences = {}
+    for path in paths:
+        if tgt_domain in path:
+            reading_path = parser_path + path
+            writing_path = model_path + 'parser_' + path
+            split = get_split(writing_path)
+        else:
+            reading_path = path
+            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
+            split = 'extra_train'
+        writing_paths[split] = writing_path
+        len_sent = 0
+        lines = []
+        sentences_list = []
+        with open(reading_path, 'r') as file:
+            for line in file:
+                line = line.strip()
+                # a blank line marks the end of a sentence
+                if len(line) == 0:
+                    # append the cleaned MA tag (column 3) as the extra auto-label column
+                    for idx in range(len_sent):
+                        lines[idx].append(clean_ma(lines[idx][3]))
+                    # add the root line first
+                    if len(lines) > 0:
+                        tmp_root_line = root_line + [root_line[3]]
+                        sentences_list.append(tmp_root_line)
+                        for line_ in lines:
+                            sentences_list.append(line_)
+                        sentences_list.append([])
+                    lines = []
+                    len_sent = 0
+                    continue
+                tokens = line.split('\t')
+                idx = tokens[0]
+                word = tokens[1]
+                pos = tokens[2]
+                ner = tokens[3]
+                head = tokens[4]
+                arc_tag = tokens[5]
+                lines.append([idx, word, pos, ner, head, arc_tag])
+                len_sent += 1
+        sentences[split] = sentences_list
+
+    train_sentences = []
+    if 'train' in sentences:
+        train_sentences = sentences['train']
+    else:
+        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
+    if 'extra_train' in sentences:
+        train_sentences += sentences['extra_train']
+        del writing_paths['extra_train']
+    if 'extra_dev' in sentences:
+        train_sentences += sentences['extra_dev']
+        del writing_paths['extra_dev']
+    with open(writing_paths['train'], 'w') as f:
+        for sent in train_sentences:
+            f.write('\t'.join(sent) + '\n')
+    for split in ['dev', 'test']:
+        if split in sentences:
+            split_sentences = sentences[split]
+            with open(writing_paths[split], 'w') as f:
+                for sent in split_sentences:
+                    f.write('\t'.join(sent) + '\n')
+    return writing_paths
+
+def predict_coarse_of_modifier(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Auto-label: the column-3 tag copied verbatim (used as a coarse label).
+    if src_domain == tgt_domain:
+        pred_paths = []
+        if use_unlabeled_data:
+            pred_paths = [file for file in os.listdir(parser_path) if
+                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
+
+        gold_paths = [file for file in os.listdir(parser_path) if
+                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
+        if use_labeled_data:
+            gold_paths += [file for file in os.listdir(parser_path) if
+                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
+
+        if not use_unlabeled_data and not use_labeled_data:
+            raise ValueError("at least one of use_unlabeled_data / use_labeled_data must be True")
+    else:
+        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
+
+        gold_paths = []
+        if use_labeled_data:
+            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
+
+        if not use_unlabeled_data and not use_labeled_data:
+            raise ValueError("at least one of use_unlabeled_data / use_labeled_data must be True")
+
+    paths = pred_paths + gold_paths
+    print('############ Add Label Task #################')
+    print("Adding labels to paths: %s" % ', '.join(paths))
+    root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
+    writing_paths = {}
+    sentences = {}
+    for path in paths:
+        if tgt_domain in path:
+            reading_path = parser_path + path
+            writing_path = model_path + 'parser_' + path
+            split = get_split(writing_path)
+        else:
+            reading_path = path
+            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
+            split = 'extra_train'
+        writing_paths[split] = writing_path
+        len_sent = 0
+        lines = []
+        sentences_list = []
+        with open(reading_path, 'r') as file:
+            for line in file:
+                line = line.strip()
+                # a blank line marks the end of a sentence
+                if len(line) == 0:
+                    # append the column-3 tag as the extra auto-label column
+                    for idx in range(len_sent):
+                        lines[idx].append(lines[idx][3])
+                    # add the root line first
+                    if len(lines) > 0:
+                        tmp_root_line = root_line + [root_line[3]]
+                        sentences_list.append(tmp_root_line)
+                        for line_ in lines:
+                            sentences_list.append(line_)
+                        sentences_list.append([])
+                    lines = []
+                    len_sent = 0
+                    continue
+                tokens = line.split('\t')
+                idx = tokens[0]
+                word = tokens[1]
+                pos = tokens[2]
+                ner = tokens[3]
+                head = tokens[4]
+                arc_tag = tokens[5]
+                lines.append([idx, word, pos, ner, head, arc_tag])
+                len_sent += 1
+        sentences[split] = sentences_list
+
+    train_sentences = []
+    if 'train' in sentences:
+        train_sentences = sentences['train']
+    else:
+        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
+    if 'extra_train' in sentences:
+        train_sentences += sentences['extra_train']
+        del writing_paths['extra_train']
+    if 'extra_dev' in sentences:
+        train_sentences += sentences['extra_dev']
+        del writing_paths['extra_dev']
+    with open(writing_paths['train'], 'w') as f:
+        for sent in train_sentences:
+            f.write('\t'.join(sent) + '\n')
+    for split in ['dev', 'test']:
+        if split in sentences:
+            split_sentences = sentences[split]
+            with open(writing_paths[split], 'w') as f:
+                for sent in split_sentences:
+                    f.write('\t'.join(sent) + '\n')
+    return writing_paths
+
+import re
+
+def get_case(ma):
+    # Extract the grammatical case (or, failing that, a coarse category) from an MA tag.
+    indeclinable = ['ind', 'prep', 'interj', 'conj', 'part']
+    case_list = ['nom', 'voc', 'acc', 'i', 'inst', 'dat', 'abl', 'g', 'loc']
+    gender_list = ['n', 'f', 'm', '*']
+    person_list = ['1', '2', '3']
+    no_list = ['du', 'sg', 'pl']
+    pops = [' ac', ' ps']
+    ma = ma.replace('sgpl', 'sg').replace('sgdu', 'sg')
+    # drop bracketed/parenthesized alternatives, then split the tag on '.'
+    temp = re.sub(r"([\(\[]).*?([\)\]])", r"\g<1>\g<2>", ma).replace('[] ', '').strip(' []')
+    temp = temp.split('.')
+    if temp[-1] == '':
+        temp.pop(-1)
+    case = ''
+    no = ''
+    person = ''
+    gender = ''
+    tense = ''
+    coarse = ''
+    # remove active/passive markers
+    # (note: popping inside enumerate can skip the element right after a match)
+    for a, b in enumerate(temp):
+        if b in pops:
+            temp.pop(a)
+    # get gender
+    for a, b in enumerate(temp):
+        if b.strip() in gender_list:
+            gender = b.strip()
+            temp.pop(a)
+    # get case
+    for a, b in enumerate(temp):
+        if b.strip() in case_list:
+            case = b.strip()
+            temp.pop(a)
+    if case != '':
+        coarse = 'Noun'
+    # get person
+    for a, b in enumerate(temp):
+        if b.strip() in person_list:
+            person = b.strip()
+            temp.pop(a)
+    # get number
+    for a, b in enumerate(temp):
+        if b.strip() in no_list:
+            no = b.strip()
+            temp.pop(a)
+    # whatever remains is treated as the tense/lexical category
+    for b in temp:
+        tense = tense + ' ' + b.strip()
+    tense = tense.strip()
+
+    if tense == 'adv':
+        coarse = 'adv'
+    for ind in indeclinable:
+        if tense == ind:
+            coarse = 'Ind'
+    if tense == 'abs' or tense == 'ca abs':
+        coarse = 'IV'
+    if tense != '' and coarse == '':
+        if person != '' or no != '':
+            coarse = 'FV'  # finite verb
+        else:
+            coarse = 'IV'  # non-finite verb form
+    if case == 'i':
+        return 'inst'
+
+    if case != '':
+        return case
+    else:
+        return coarse
+
+def clean_ma(ma):
+    # Normalize an MA tag: drop bracketed alternatives and active/passive markers,
+    # collapse ambiguous numbers, expand 'i.' to 'inst.', then strip dots and spaces.
+    ma = re.sub(r"([\(\[]).*?([\)\]])", r"\g<1>\g<2>", ma).replace('[] ', '').strip(' []') \
+           .replace(' ac', '').replace(' ps', '').replace('sgpl', 'sg').replace('sgdu', 'sg')
+    ma = ma.replace('i.', 'inst.').replace('.', '').replace(' ', '')
+    return ma
+
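To make the two helpers above concrete, here is how they behave on tags of the rough shape produced by a morphological analyser (the exact tag inventory lives in the data files, so these inputs are only illustrative):

get_case('m. sg. nom.')   # -> 'nom'    (case found; gender and number stripped)
get_case('abs.')          # -> 'IV'     (absolutive, a non-finite verb form)
get_case('pre. 3. sg.')   # -> 'FV'     (person/number present, so a finite verb)
clean_ma('m. sg. nom.')   # -> 'msgnom' (dots and spaces removed)
clean_ma('i. sg.')        # -> 'instsg' ('i.' expanded to 'inst.')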
+def predict_case_of_modifier(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Auto-label: the case extracted from the MA tag (column 3, via get_case above).
+    if src_domain == tgt_domain:
+        pred_paths = []
+        if use_unlabeled_data:
+            pred_paths = [file for file in os.listdir(parser_path) if
+                          file.endswith("pred.txt") and 'extra' in file and tgt_domain in file]
+
+        gold_paths = [file for file in os.listdir(parser_path) if
+                      file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' not in file]
+        if use_labeled_data:
+            gold_paths += [file for file in os.listdir(parser_path) if
+                           file.endswith("gold.txt") and 'extra' not in file and tgt_domain in file and 'train' in file]
+
+        if not use_unlabeled_data and not use_labeled_data:
+            raise ValueError("at least one of use_unlabeled_data / use_labeled_data must be True")
+    else:
+        pred_paths = [file for file in os.listdir(parser_path) if file.endswith("pred.txt") and tgt_domain in file]
+
+        gold_paths = []
+        if use_labeled_data:
+            gold_paths = ['data/onto_pos_ner_dp_train_' + src_domain]
+
+        if not use_unlabeled_data and not use_labeled_data:
+            raise ValueError("at least one of use_unlabeled_data / use_labeled_data must be True")
+
+    paths = pred_paths + gold_paths
+    print('############ Add Label Task #################')
+    print("Adding labels to paths: %s" % ', '.join(paths))
+    root_line = ['0', ROOT, 'XX', 'O', '0', 'root']
+    writing_paths = {}
+    sentences = {}
+    for path in paths:
+        if tgt_domain in path:
+            reading_path = parser_path + path
+            writing_path = model_path + 'parser_' + path
+            split = get_split(writing_path)
+        else:
+            reading_path = path
+            writing_path = model_path + 'parser_' + 'domain_' + src_domain + '_train_model_domain_' + src_domain + '_data_domain_' + src_domain + '_gold.txt'
+            split = 'extra_train'
+        writing_paths[split] = writing_path
+        len_sent = 0
+        lines = []
+        sentences_list = []
+        with open(reading_path, 'r') as file:
+            for line in file:
+                line = line.strip()
+                # a blank line marks the end of a sentence
+                if len(line) == 0:
+                    # append the extracted case (from column 3) as the extra auto-label column
+                    for idx in range(len_sent):
+                        lines[idx].append(get_case(lines[idx][3]))
+                    # add the root line first
+                    if len(lines) > 0:
+                        tmp_root_line = root_line + [root_line[3]]
+                        sentences_list.append(tmp_root_line)
+                        for line_ in lines:
+                            sentences_list.append(line_)
+                        sentences_list.append([])
+                    lines = []
+                    len_sent = 0
+                    continue
+                tokens = line.split('\t')
+                idx = tokens[0]
+                word = tokens[1]
+                pos = tokens[2]
+                ner = tokens[3]
+                head = tokens[4]
+                arc_tag = tokens[5]
+                lines.append([idx, word, pos, ner, head, arc_tag])
+                len_sent += 1
+        sentences[split] = sentences_list
+
+    train_sentences = []
+    if 'train' in sentences:
+        train_sentences = sentences['train']
+    else:
+        writing_paths['train'] = writing_paths['extra_train'].replace('extra_train', 'train')
+    if 'extra_train' in sentences:
+        train_sentences += sentences['extra_train']
+        del writing_paths['extra_train']
+    if 'extra_dev' in sentences:
+        train_sentences += sentences['extra_dev']
+        del writing_paths['extra_dev']
+    with open(writing_paths['train'], 'w') as f:
+        for sent in train_sentences:
+            f.write('\t'.join(sent) + '\n')
+    for split in ['dev', 'test']:
+        if split in sentences:
+            split_sentences = sentences[split]
+            with open(writing_paths[split], 'w') as f:
+                for sent in split_sentences:
+                    f.write('\t'.join(sent) + '\n')
+    return writing_paths
+
+def Multitask_case_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Fixed paths to pre-computed case-label splits for Sanskrit.
+    writing_paths = {}
+    writing_paths['train'] = 'data/ud_pos_ner_dp_train_san_case'
+    writing_paths['dev'] = 'data/ud_pos_ner_dp_dev_san_case'
+    writing_paths['test'] = 'data/ud_pos_ner_dp_test_san_case'
+    return writing_paths
+
+def Multitask_POS_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Fixed paths to pre-computed POS-label splits for Sanskrit.
+    writing_paths = {}
+    writing_paths['train'] = 'data/ud_pos_ner_dp_train_san_POS'
+    writing_paths['dev'] = 'data/ud_pos_ner_dp_dev_san_POS'
+    writing_paths['test'] = 'data/ud_pos_ner_dp_test_san_POS'
+    return writing_paths
+
+def Multitask_coarse_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Fixed paths to pre-computed coarse-label splits for Sanskrit.
+    writing_paths = {}
+    writing_paths['train'] = 'data/Multitask_coarse_predict_train_san'
+    writing_paths['dev'] = 'data/Multitask_coarse_predict_dev_san'
+    writing_paths['test'] = 'data/Multitask_coarse_predict_test_san'
+    return writing_paths
+
+def Multitask_label_predict(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Fixed paths to pre-computed dependency-label splits for Sanskrit.
+    writing_paths = {}
+    writing_paths['train'] = 'data/Multitask_label_predict_train_san'
+    writing_paths['dev'] = 'data/Multitask_label_predict_dev_san'
+    writing_paths['test'] = 'data/Multitask_label_predict_test_san'
+    return writing_paths
+
+def MRL_case(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    # Fixed paths to pre-computed case-label splits for an MRL source domain.
+    writing_paths = {}
+    writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_' + src_domain + '_case'
+    writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_' + src_domain + '_case'
+    writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_' + src_domain + '_case'
+    return writing_paths
+
+def MRL_POS(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    writing_paths = {}
+    writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_' + src_domain + '_POS'
+    writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_' + src_domain + '_POS'
+    writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_' + src_domain + '_POS'
+    return writing_paths
+
+def MRL_label(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    writing_paths = {}
+    writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_' + src_domain + '_dep'
+    writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_' + src_domain + '_dep'
+    writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_' + src_domain + '_dep'
+    return writing_paths
+
+def MRL_no(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    writing_paths = {}
+    writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_' + src_domain + '_no'
+    writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_' + src_domain + '_no'
+    writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_' + src_domain + '_no'
+    return writing_paths
+
+def MRL_Person(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    writing_paths = {}
+    writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_' + src_domain + '_per'
+    writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_' + src_domain + '_per'
+    writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_' + src_domain + '_per'
+    return writing_paths
+
+def MRL_Gender(model_path, parser_path, src_domain, tgt_domain, use_unlabeled_data=True, use_labeled_data=True):
+    writing_paths = {}
+    writing_paths['train'] = 'data/Prep_MRL/ud_pos_ner_dp_train_' + src_domain + '_gen'
+    writing_paths['dev'] = 'data/Prep_MRL/ud_pos_ner_dp_dev_' + src_domain + '_gen'
+    writing_paths['test'] = 'data/Prep_MRL/ud_pos_ner_dp_test_' + src_domain + '_gen'
+    return writing_paths
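Each helper above (like the predict_*_of_modifier functions before them) honors the same contract: it returns a dict mapping split names to CoNLL-style files that carry the extra auto-label column, which the training scripts then read. A minimal consumption sketch with illustrative argument values (the real call site is not part of this hunk):

paths = MRL_case(model_path='saved_models/', parser_path='saved_models/parser/',
                 src_domain='san', tgt_domain='san')
# paths == {'train': 'data/Prep_MRL/ud_pos_ner_dp_train_san_case', 'dev': ..., 'test': ...}
with open(paths['train']) as f:
    for line in f:
        if not line.strip():
            break  # a blank line ends the first sentence
        columns = line.rstrip('\n').split('\t')  # idx, word, pos, MA/NER, head, arc, auto-label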
utils/io_/writer.py
ADDED
@@ -0,0 +1,46 @@
+
+class Writer(object):
+    def __init__(self, alphabets):
+        self.__source_file = None
+        self.alphabets = alphabets
+
+    def start(self, file_path):
+        self.__source_file = open(file_path, 'w')
+
+    def close(self):
+        self.__source_file.close()
+
+    def write(self, word, pos, ner, head, arc, lengths, auto_label=None, symbolic_root=False, symbolic_end=False):
+        # word/pos/ner/arc are (batch, seq_len) index arrays; one CoNLL-style line per token
+        batch_size, _ = word.shape
+        start = 1 if symbolic_root else 0
+        end = 1 if symbolic_end else 0
+        for i in range(batch_size):
+            for j in range(start, lengths[i] - end):
+                w = self.alphabets['word_alphabet'].get_instance(word[i, j])
+                p = self.alphabets['pos_alphabet'].get_instance(pos[i, j])
+                n = self.alphabets['ner_alphabet'].get_instance(ner[i, j])
+                t = self.alphabets['arc_alphabet'].get_instance(arc[i, j])
+                h = head[i, j]
+                if auto_label is not None:
+                    m = self.alphabets['auto_label_alphabet'].get_instance(auto_label[i, j])
+                    self.__source_file.write('%d\t%s\t%s\t%s\t%d\t%s\t%s\n' % (j, w, p, n, h, t, m))
+                else:
+                    self.__source_file.write('%d\t%s\t%s\t%s\t%d\t%s\n' % (j, w, p, n, h, t))
+            self.__source_file.write('\n')
+
+class Index2Instance(object):
+    def __init__(self, alphabet):
+        self.__alphabet = alphabet
+
+    def index2instance(self, indices, lengths, symbolic_root=False, symbolic_end=False):
+        # map a (batch, seq_len) index array back to its string instances
+        batch_size, _ = indices.shape
+        start = 1 if symbolic_root else 0
+        end = 1 if symbolic_end else 0
+        instances = []
+        for i in range(batch_size):
+            tmp_instances = []
+            for j in range(start, lengths[i] - end):
+                instance = self.__alphabet.get_instance(indices[i, j])
+                tmp_instances.append(instance)
+            instances.append(tmp_instances)
+        return instances
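A self-contained usage sketch for Writer; StubAlphabet is a stand-in for the project's alphabet objects, of which only get_instance is assumed here:

import numpy as np

class StubAlphabet(object):
    def __init__(self, items):
        self.items = items
    def get_instance(self, index):
        return self.items[index]

alphabets = {
    'word_alphabet': StubAlphabet(['<root>', 'ramaH', 'gacchati']),
    'pos_alphabet': StubAlphabet(['XX', 'NOUN', 'VERB']),
    'ner_alphabet': StubAlphabet(['O', 'O', 'O']),
    'arc_alphabet': StubAlphabet(['root', 'nsubj', 'root']),
}
writer = Writer(alphabets)
writer.start('/tmp/example_pred.txt')
# one sentence of length 3, where position 0 is the symbolic root
word = np.array([[0, 1, 2]]); pos = np.array([[0, 1, 2]])
ner = np.array([[0, 1, 2]]); arc = np.array([[0, 1, 2]])
head = np.array([[0, 2, 0]]); lengths = [3]
writer.write(word, pos, ner, head, arc, lengths, symbolic_root=True)
writer.close()
# /tmp/example_pred.txt now holds two tab-separated token lines plus a blank line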
utils/load_word_embeddings.py
ADDED
@@ -0,0 +1,123 @@
+import pickle
+import numpy as np
+from gensim.models import KeyedVectors
+import gzip
+import io
+import os
+
+def calc_mean_vec_for_lower_mapping(embedd_dict):
+    lower_counts = {}
+    for word in embedd_dict:
+        word_lower = word.lower()
+        if word_lower not in lower_counts:
+            lower_counts[word_lower] = [word]
+        else:
+            lower_counts[word_lower] = lower_counts[word_lower] + [word]
+    # calculate the mean vector over all words that share the same lower() mapping
+    # (axis=0 keeps the result a vector rather than collapsing it to a scalar)
+    for word in lower_counts:
+        embedd_dict[word] = np.mean([embedd_dict[word_] for word_ in lower_counts[word]], axis=0)
+    return embedd_dict
+
+def load_embedding_dict(embedding, embedding_path, lower_case=False):
+    """
+    load word embeddings from file
+    :param embedding: embedding type, one of [glove, fasttext, hellwig, one_hot, word2vec]
+    :param embedding_path: path to the embedding file
+    :return: embedding dict, embedding dimension
+    """
+    print("loading embedding: %s from %s" % (embedding, embedding_path))
+    if lower_case:
+        pkl_path = embedding_path + '_lower' + '.pkl'
+    else:
+        pkl_path = embedding_path + '.pkl'
+    if os.path.isfile(pkl_path):
+        # load dict and dim from a cached pickle file
+        with open(pkl_path, 'rb') as f:
+            embedd_dict, embedd_dim = pickle.load(f)
+        print("num dimensions of word embeddings:", embedd_dim)
+        return embedd_dict, embedd_dim
+
+    if embedding == 'glove':
+        # loading GloVe: one "word v1 v2 ..." entry per line
+        embedd_dict = {}
+        word = None
+        with io.open(embedding_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                word, vec = line.split(' ', 1)
+                embedd_dict[word] = np.fromstring(vec, sep=' ')
+        embedd_dim = len(embedd_dict[word])
+        if lower_case:
+            embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
+        # report any dimension mismatches
+        for k, v in embedd_dict.items():
+            if len(v) != embedd_dim:
+                print(len(v), embedd_dim)
+
+    elif embedding == 'fasttext':
+        # loading fastText: same text format as GloVe, plus a header line to skip
+        embedd_dict = {}
+        word = None
+        with io.open(embedding_path, 'r', encoding='utf-8') as f:
+            for i, line in enumerate(f):
+                if i == 0:
+                    continue  # skip the header line
+                word, vec = line.split(' ', 1)
+                embedd_dict[word] = np.fromstring(vec, sep=' ')
+        embedd_dim = len(embedd_dict[word])
+        if lower_case:
+            embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
+        for k, v in embedd_dict.items():
+            if len(v) != embedd_dim:
+                print(len(v), embedd_dim)
+
+    elif embedding == 'hellwig':
+        # loading hellwig embeddings: fastText-like text format with a header line
+        embedd_dict = {}
+        word = None
+        with io.open(embedding_path, 'r', encoding='utf-8') as f:
+            for i, line in enumerate(f):
+                if i == 0:
+                    continue  # skip the header line
+                word, vec = line.split(' ', 1)
+                embedd_dict[word] = np.fromstring(vec, sep=' ')
+        embedd_dim = len(embedd_dict[word])
+        if lower_case:
+            embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
+        for k, v in embedd_dict.items():
+            if len(v) != embedd_dim:
+                print(len(v), embedd_dim)
+
+    elif embedding == 'one_hot':
+        # loading one-hot vectors: word and vector are separated by '@'
+        embedd_dict = {}
+        word = None
+        with io.open(embedding_path, 'r', encoding='utf-8') as f:
+            for i, line in enumerate(f):
+                if i == 0:
+                    continue  # skip the header line
+                word, vec = line.split('@', 1)
+                embedd_dict[word] = np.fromstring(vec, sep=' ')
+        embedd_dim = len(embedd_dict[word])
+        if lower_case:
+            embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
+        for k, v in embedd_dict.items():
+            if len(v) != embedd_dim:
+                print(len(v), embedd_dim)
+
+    elif embedding == 'word2vec':
+        # loading word2vec binary format via gensim
+        embedd_dict = KeyedVectors.load_word2vec_format(embedding_path, binary=True)
+        if lower_case:
+            embedd_dict = calc_mean_vec_for_lower_mapping(embedd_dict)
+        embedd_dim = embedd_dict.vector_size
+
+    else:
+        raise ValueError("embedding should be one of [glove, fasttext, hellwig, one_hot, word2vec]")
+
+    print("num dimensions of word embeddings:", embedd_dim)
+    # cache dict and dim to a pickle file for faster reloads
+    with open(pkl_path, 'wb') as f:
+        pickle.dump([embedd_dict, embedd_dim], f, pickle.HIGHEST_PROTOCOL)
+    return embedd_dict, embedd_dim
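A quick usage sketch for load_embedding_dict; the vector path is a hypothetical placeholder, and the first call parses the text file and caches the result as a .pkl beside it, so later calls hit the cache:

# 'path/to/cc.sa.300.vec' stands in for any fastText-format vector file
embedd_dict, embedd_dim = load_embedding_dict('fasttext', 'path/to/cc.sa.300.vec')
print(embedd_dim)              # e.g. 300
vec = embedd_dict.get('rama')  # None if the word is out of vocabulary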
utils/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .parsing import *
+from .parsing_gating import *
+from .sequence_tagger import *
utils/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (190 Bytes)

utils/models/__pycache__/parsing.cpython-37.pyc
ADDED
Binary file (10.5 kB)

utils/models/__pycache__/parsing_gating.cpython-37.pyc
ADDED
Binary file (13.7 kB)