summaryrefslogtreecommitdiff
path: root/array.c
blob: 7675144dd27f890e330598ac568f208a7bef35f7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <u.h>
#include <libc.h>
#include <bio.h>

#include "apl9.h"

/* Bytes occupied by one element of each array type; indexed by an Atype* value. */
int datasizes[] = {
	[AtypeArray] = sizeof(Array*),
	[AtypeInt] = sizeof(vlong),
};

/* Wrap the integer i in a rank-0 (scalar) AtypeInt array obtained from allocarray. */
Array *
mkscalarint(vlong i)
{
	Array *scalar;

	scalar = allocarray(AtypeInt, 0, 1);
	scalar->intdata[0] = i;
	return scalar;
}

/*
 * Return a copy of a with the same type, rank, shape and element data.
 * For nested (AtypeArray) arrays the child pointers are copied as-is and
 * each child is incref'd, so the copy shares the children with a.
 */
Array *
duparray(Array *a)
{
	Array *copy;
	int n;

	copy = allocarray(a->type, a->rank, a->size);
	memcpy(copy->shape, a->shape, a->rank * sizeof(int));
	memcpy(copy->rawdata, a->rawdata, a->size * datasizes[a->type]);
	if(copy->type != AtypeArray)
		return copy;
	for(n = 0; n < copy->size; n++)
		incref(copy->arraydata[n]);
	return copy;
}

/* Returns 1 when a holds plain elements (its type is not AtypeArray), 0 otherwise. */
int
simplearray(Array *a)
{
	if(a->type == AtypeArray)
		return 0;
	return 1;
}

/* Returns 1 when a has rank 0 and is simple (see simplearray), 0 otherwise. */
int
simplescalar(Array *a)
{
	if(a->rank != 0)
		return 0;
	return simplearray(a);
}

/*
 * Extend the singleton a over b: the result is a reshaped to b's shape,
 * i.e. fnReshape(fnShape(b), a). The temporary shape vector is released
 * with freearray before returning.
 */
Array *
extend(Array *a, Array *b)
{
	Array *target, *out;

	target = fnShape(b);
	out = fnReshape(target, a);
	freearray(target);
	return out;
}

int
scalarextend(Array *a, Array *b, Array **aa, Array **bb)
{
	/* Bring a and b to a common shape for an element-wise operation.
	   On success returns 1 and stores the conforming arrays in *aa
	   and *bb. On failure returns 0 and leaves *aa and *bb unchanged.

	   - If a is a singleton (size 1) and b is not, a is reshaped to
	     b's shape (via extend); likewise with the roles swapped.
	   - If both are singletons but of different rank (e.g. a scalar
	     and a one-element vector), the lower-rank one is reshaped to
	     the higher-rank one's shape. Without this case such pairs
	     would reach the exact-shape check below and be rejected,
	     unlike every other singleton pairing.
	   - Otherwise the ranks and every dimension must match exactly.
	*/
	int asingle = a->size == 1;
	int bsingle = b->size == 1;

	if(asingle && (!bsingle || a->rank < b->rank)){
		*aa = extend(a, b);
		*bb = fnSame(b);
	}else if(bsingle && (!asingle || b->rank < a->rank)){
		*aa = fnSame(a);
		*bb = extend(b, a);
	}else if(a->size == b->size && a->rank == b->rank){
		/* Check that each dimension matches */
		for(int i = 0; i < a->rank; i++)
			if(a->shape[i] != b->shape[i])
				return 0;
		*aa = fnSame(a);
		*bb = fnSame(b);
	}else
		return 0;
	return 1;
}

/*
 * Return element number index (in ravel order of the element storage)
 * of a as an Array.
 *   AtypeInt:   a new rank-0 integer array built with mkscalarint.
 *   AtypeArray: the stored child itself, after an incref.
 * In both cases the caller presumably owns the returned reference —
 * NOTE(review): confirm against callers.
 * index is not bounds-checked; callers must pass 0 <= index < a->size.
 * Any other type is an internal error: a message is printed and the
 * process exits with a non-nil (failure) status.
 */
Array *
arrayitem(Array *a, int index)
{
	Array *res = nil;
	switch(a->type){
	case AtypeInt:
		res = mkscalarint(a->intdata[index]);
		break;
	case AtypeArray:
		res = a->arraydata[index];
		incref(res);
		break;
	default:
		print("Unhandled case in arrayitem()\n");
		/* exits(nil) would report success; an internal error must not. */
		exits("arrayitem: unhandled array type");
	}
	return res;
}

Array *
simplifyarray(Array *a)
{
	/* Simplify an array if possible: a nested (AtypeArray) array whose
	   items are all simple scalars of one type is converted into a flat
	   array of that type with the same rank and shape. Anything else is
	   returned via fnSame.

	   Items that are themselves nested (AtypeArray), even rank-0
	   enclosures, are never unwrapped. Unwrapping them would drop a
	   level of nesting (changing the value), and it would copy child
	   pointers into the result without incref'ing them (duparray shows
	   that copied child pointers need an incref). */
	if(a->type != AtypeArray || a->size == 0)
		return fnSame(a);
	int type = a->arraydata[0]->type;
	int i;
	for(i = 0; i < a->size; i++){
		Array *item = a->arraydata[i];
		if(!simplescalar(item) || item->type != type)
			return fnSame(a);
	}
	/* Every item is a simple rank-0 array of `type`: pack their single
	   elements contiguously into a flat array. */
	Array *b = allocarray(type, a->rank, a->size);
	for(i = 0; i < a->rank; i++)
		b->shape[i] = a->shape[i];
	for(i = 0; i < a->size; i++)
		memcpy(b->rawdata + i * datasizes[type], a->arraydata[i]->rawdata, datasizes[type]);
	return b;
}