source: ThirdParty/mpqc_open/src/lib/math/scmat/distdiag.cc

Candidate_v1.6.1
Last change on this file was 860145, checked in by Frederik Heber <heber@…>, 8 years ago

Merge commit '0b990dfaa8c6007a996d030163a25f7f5fc8a7e7' as 'ThirdParty/mpqc_open'

  • Property mode set to 100644
File size: 9.0 KB
Line 
1//
2// distdiag.cc
3//
4// Copyright (C) 1996 Limit Point Systems, Inc.
5//
6// Author: Curtis Janssen <cljanss@limitpt.com>
7// Maintainer: LPS
8//
9// This file is part of the SC Toolkit.
10//
11// The SC Toolkit is free software; you can redistribute it and/or modify
12// it under the terms of the GNU Library General Public License as published by
13// the Free Software Foundation; either version 2, or (at your option)
14// any later version.
15//
16// The SC Toolkit is distributed in the hope that it will be useful,
17// but WITHOUT ANY WARRANTY; without even the implied warranty of
18// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19// GNU Library General Public License for more details.
20//
21// You should have received a copy of the GNU Library General Public License
22// along with the SC Toolkit; see the file COPYING.LIB. If not, write to
23// the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
24//
25// The U.S. Government is granted a limited license as per AL 91-7.
26//
27
28#include <iostream>
29#include <math.h>
30
31#include <util/misc/formio.h>
32#include <util/keyval/keyval.h>
33#include <math/scmat/dist.h>
34#include <math/scmat/cmatrix.h>
35#include <math/scmat/elemop.h>
36
37using namespace std;
38using namespace sc;
39
40#define DEBUG 0
41
42/////////////////////////////////////////////////////////////////////////////
43// DistDiagSCMatrix member functions
44
45static ClassDesc DistDiagSCMatrix_cd(
46 typeid(DistDiagSCMatrix),"DistDiagSCMatrix",1,"public DiagSCMatrix",
47 0, 0, 0);
48
49DistDiagSCMatrix::DistDiagSCMatrix(const RefSCDimension&a,DistSCMatrixKit*k):
50 DiagSCMatrix(a,k)
51{
52 init_blocklist();
53}
54
55int
56DistDiagSCMatrix::block_to_node(int i) const
57{
58 return (i)%messagegrp()->n();
59}
60
61Ref<SCMatrixBlock>
62DistDiagSCMatrix::block_to_block(int i) const
63{
64 int offset = i;
65 int nproc = messagegrp()->n();
66
67 if ((offset%nproc) != messagegrp()->me()) return 0;
68
69 SCMatrixBlockListIter I;
70 for (I=blocklist->begin(); I!=blocklist->end(); I++) {
71 if (I.block()->blocki == i)
72 return I.block();
73 }
74
75 ExEnv::errn() << indent
76 << "DistDiagSCMatrix::block_to_block: internal error" << endl;
77 abort();
78 return 0;
79}
80
81double *
82DistDiagSCMatrix::find_element(int i) const
83{
84 int bi, oi;
85 d->blocks()->elem_to_block(i, bi, oi);
86
87 if (DEBUG)
88 ExEnv::outn() << messagegrp()->me() << ": "
89 << "find_element(" << i << "): "
90 << "block = " << bi << ", "
91 << "offset = " << oi
92 << endl;
93
94 Ref<SCMatrixDiagBlock> blk; blk << block_to_block(bi);
95 if (blk.nonnull()) {
96 if (DEBUG)
97 ExEnv::outn() << messagegrp()->me() << ": ndat = " << blk->ndat() << endl;
98 if (oi >= blk->ndat()) {
99 ExEnv::errn() << messagegrp()->me() << ": DistDiagSCMatrix::find_element"
100 << ": internal error" << endl;
101 abort();
102 }
103 return &blk->dat()[oi];
104 }
105 else {
106 if (DEBUG)
107 ExEnv::outn() << messagegrp()->me() << ": can't find" << endl;
108 return 0;
109 }
110}
111
112int
113DistDiagSCMatrix::element_to_node(int i) const
114{
115 int bi, oi;
116 d->blocks()->elem_to_block(i, bi, oi);
117
118 return block_to_node(bi);
119}
120
121void
122DistDiagSCMatrix::init_blocklist()
123{
124 int i;
125 int nproc = messagegrp()->n();
126 int me = messagegrp()->me();
127 blocklist = new SCMatrixBlockList;
128 SCMatrixBlock *b;
129 for (i=0; i<d->blocks()->nblock(); i++) {
130 if (i%nproc != me) continue;
131 b = new SCMatrixDiagBlock(d->blocks()->start(i),
132 d->blocks()->fence(i),
133 d->blocks()->start(i));
134 b->blocki = i;
135 b->blockj = i;
136 blocklist->insert(b);
137 }
138}
139
140DistDiagSCMatrix::~DistDiagSCMatrix()
141{
142}
143
144double
145DistDiagSCMatrix::get_element(int i) const
146{
147 double res;
148 double *e = find_element(i);
149 if (e) {
150 res = *e;
151 messagegrp()->bcast(res, messagegrp()->me());
152 }
153 else {
154 messagegrp()->bcast(res, element_to_node(i));
155 }
156 return res;
157}
158
159void
160DistDiagSCMatrix::set_element(int i,double a)
161{
162 double *e = find_element(i);
163 if (e) {
164 *e = a;
165 }
166}
167
168void
169DistDiagSCMatrix::accumulate_element(int i,double a)
170{
171 double *e = find_element(i);
172 if (e) {
173 *e += a;
174 }
175}
176
177void
178DistDiagSCMatrix::accumulate(const DiagSCMatrix*a)
179{
180 // make sure that the argument is of the correct type
181 const DistDiagSCMatrix* la
182 = require_dynamic_cast<const DistDiagSCMatrix*>(a,"DistDiagSCMatrix::accumulate");
183
184 // make sure that the dimensions match
185 if (!dim()->equiv(la->dim())) {
186 ExEnv::errn() << indent << "DistDiagSCMatrix::accumulate(SCMatrix*a): "
187 << "dimensions don't match\n";
188 abort();
189 }
190
191 SCMatrixBlockListIter i1, i2;
192 for (i1=la->blocklist->begin(),i2=blocklist->begin();
193 i1!=la->blocklist->end() && i2!=blocklist->end();
194 i1++,i2++) {
195 int n = i1.block()->ndat();
196 if (n != i2.block()->ndat()) {
197 ExEnv::errn() << indent << "DistDiagSCMatrix::accumulate "
198 << "mismatch: internal error" << endl;
199 abort();
200 }
201 double *dat1 = i1.block()->dat();
202 double *dat2 = i2.block()->dat();
203 for (int i=0; i<n; i++) {
204 dat2[i] += dat1[i];
205 }
206 }
207}
208
209double
210DistDiagSCMatrix::invert_this()
211{
212 Ref<SCMatrixSubblockIter> I = local_blocks(SCMatrixSubblockIter::Read);
213 double det = 1.0;
214 for (I->begin(); I->ready(); I->next()) {
215 int n = I->block()->ndat();
216 double *data = I->block()->dat();
217 for (int i=0; i<n; i++) {
218 det *= data[i];
219 data[i] = 1.0/data[i];
220 }
221 }
222 GrpProductReduce<double> gred;
223 messagegrp()->reduce(&det, 1, gred);
224 return det;
225}
226
227double
228DistDiagSCMatrix::determ_this()
229{
230 Ref<SCMatrixSubblockIter> I = local_blocks(SCMatrixSubblockIter::Read);
231 double det = 1.0;
232 for (I->begin(); I->ready(); I->next()) {
233 int n = I->block()->ndat();
234 double *data = I->block()->dat();
235 for (int i=0; i<n; i++) {
236 det *= data[i];
237 }
238 }
239 GrpProductReduce<double> gred;
240 messagegrp()->reduce(det, gred);
241 return det;
242}
243
244double
245DistDiagSCMatrix::trace()
246{
247 double ret=0.0;
248 Ref<SCMatrixSubblockIter> I = local_blocks(SCMatrixSubblockIter::Read);
249 for (I->begin(); I->ready(); I->next()) {
250 int n = I->block()->ndat();
251 double *data = I->block()->dat();
252 for (int i=0; i<n; i++) {
253 ret += data[i];
254 }
255 }
256 messagegrp()->sum(ret);
257 return ret;
258}
259
260void
261DistDiagSCMatrix::gen_invert_this()
262{
263 Ref<SCMatrixSubblockIter> I = local_blocks(SCMatrixSubblockIter::Read);
264 for (I->begin(); I->ready(); I->next()) {
265 int n = I->block()->ndat();
266 double *data = I->block()->dat();
267 for (int i=0; i<n; i++) {
268 if (fabs(data[i]) > 1.0e-8)
269 data[i] = 1.0/data[i];
270 else
271 data[i] = 0.0;
272 }
273 }
274}
275
276void
277DistDiagSCMatrix::element_op(const Ref<SCElementOp>& op)
278{
279 SCMatrixBlockListIter i;
280 for (i = blocklist->begin(); i != blocklist->end(); i++) {
281 op->process_base(i.block());
282 }
283 if (op->has_collect()) op->collect(messagegrp());
284}
285
286void
287DistDiagSCMatrix::element_op(const Ref<SCElementOp2>& op,
288 DiagSCMatrix* m)
289{
290 DistDiagSCMatrix *lm
291 = require_dynamic_cast<DistDiagSCMatrix*>(m,"DistDiagSCMatrix::element_op");
292
293 if (!dim()->equiv(lm->dim())) {
294 ExEnv::errn() << indent << "DistDiagSCMatrix: bad element_op\n";
295 abort();
296 }
297 SCMatrixBlockListIter i, j;
298 for (i = blocklist->begin(), j = lm->blocklist->begin();
299 i != blocklist->end();
300 i++, j++) {
301 op->process_base(i.block(), j.block());
302 }
303 if (op->has_collect()) op->collect(messagegrp());
304}
305
306void
307DistDiagSCMatrix::element_op(const Ref<SCElementOp3>& op,
308 DiagSCMatrix* m,DiagSCMatrix* n)
309{
310 DistDiagSCMatrix *lm
311 = require_dynamic_cast<DistDiagSCMatrix*>(m,"DistDiagSCMatrix::element_op");
312 DistDiagSCMatrix *ln
313 = require_dynamic_cast<DistDiagSCMatrix*>(n,"DistDiagSCMatrix::element_op");
314
315 if (!dim()->equiv(lm->dim()) || !dim()->equiv(ln->dim())) {
316 ExEnv::errn() << indent << "DistDiagSCMatrix: bad element_op\n";
317 abort();
318 }
319 SCMatrixBlockListIter i, j, k;
320 for (i = blocklist->begin(),
321 j = lm->blocklist->begin(),
322 k = ln->blocklist->begin();
323 i != blocklist->end();
324 i++, j++, k++) {
325 op->process_base(i.block(), j.block(), k.block());
326 }
327 if (op->has_collect()) op->collect(messagegrp());
328}
329
330Ref<SCMatrixSubblockIter>
331DistDiagSCMatrix::local_blocks(SCMatrixSubblockIter::Access access)
332{
333 return new SCMatrixListSubblockIter(access, blocklist);
334}
335
336Ref<SCMatrixSubblockIter>
337DistDiagSCMatrix::all_blocks(SCMatrixSubblockIter::Access access)
338{
339 return new DistSCMatrixListSubblockIter(access, blocklist, messagegrp());
340}
341
342void
343DistDiagSCMatrix::error(const char *msg)
344{
345 ExEnv::errn() << indent << "DistDiagSCMatrix: error: " << msg << endl;
346}
347
348Ref<DistSCMatrixKit>
349DistDiagSCMatrix::skit()
350{
351 return dynamic_cast<DistSCMatrixKit*>(kit().pointer());
352}
353
354/////////////////////////////////////////////////////////////////////////////
355
356// Local Variables:
357// mode: c++
358// c-file-style: "CLJ"
359// End:
Note: See TracBrowser for help on using the repository browser.